Conversation
fef7530 to
17d88b1
Compare
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2026-03-20 20:13:55.127307+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2026-03-20 20:14:14.887149+00:00
@@ -623,11 +623,16 @@
for _fi, _fdim in enumerate(F):
_s = input_tensor.shape[_fdim]
if _s == DYNAMIC_DIM:
F_shape_values.append(
get_shape(
- ctx, target, source_ir, f"{name}_fshape_{_fdim}", input_tensor, _fdim
+ ctx,
+ target,
+ source_ir,
+ f"{name}_fshape_{_fdim}",
+ input_tensor,
+ _fdim,
)
)
else:
F_shape_values.append(_s)
_has_dynamic_f = any(isinstance(_s, TRTTensor) for _s in F_shape_values)
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-03-20 20:13:55.159208+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-03-20 20:14:16.694415+00:00
@@ -357,11 +357,10 @@
)
result = trt_engine(source_tensor, indices_tensor, value_tensor)
torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)
-
def test_kv_cache_dynamic_batch(self):
"""index_put with a dynamic free dimension (batch) — issue #4139.
Pattern: cache[..., idx, :] = values where dim-1 (batch) is dynamic
and dim-2 (cache/time) is the indexed static dimension.
@@ -420,12 +419,12 @@
use_explicit_typing=True,
min_block_size=1,
)
result = trt_mod(cache.clone(), values, idx)
- assert torch.allclose(result, torch_output, atol=1e-3, rtol=1e-3), (
- f"KV-cache index_put mismatch: max diff = {(result - torch_output).abs().max()}"
- )
+ assert torch.allclose(
+ result, torch_output, atol=1e-3, rtol=1e-3
+ ), f"KV-cache index_put mismatch: max diff = {(result - torch_output).abs().max()}"
if __name__ == "__main__":
run_tests()
17d88b1 to
858bf17
Compare
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-03-20 22:25:06.608217+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-03-20 22:25:28.809312+00:00
@@ -392,12 +392,16 @@
)
trt_mod = torchtrt.dynamo.compile(
ep,
arg_inputs=[
torchtrt.Input(shape=(16,), dtype=torch.float32),
- torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32),
- torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32),
+ torchtrt.Input(
+ min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32
+ ),
+ torchtrt.Input(
+ min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32
+ ),
],
min_block_size=1,
)
result = trt_mod(src.clone(), values, idx)
assert torch.allclose(
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2026-03-23 18:18:24.127073+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2026-03-23 18:18:42.849840+00:00
@@ -772,13 +772,15 @@
ctx, target, source_ir, f"{name}_result_flat", src_flat, delta
)
# Rebuild the output shape (may contain dynamic dims)
out_shape = tuple(
- get_shape(ctx, target, source_ir, f"{name}_oshape_{i}", input_tensor, i)
- if input_tensor.shape[i] == DYNAMIC_DIM
- else input_tensor.shape[i]
+ (
+ get_shape(ctx, target, source_ir, f"{name}_oshape_{i}", input_tensor, i)
+ if input_tensor.shape[i] == DYNAMIC_DIM
+ else input_tensor.shape[i]
+ )
for i in range(rank)
)
return impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_result", result_flat, out_shape
)
@@ -871,11 +873,13 @@
F = [i for i in range(rank) if indices[i] is None] # Free dimensions
I = [i for i in range(rank) if indices[i] is not None] # Indexed dimensions
K = len(I)
# Determine the maximum size 'N' among the index tensors
if K > 0:
- index_shapes = [] # [tensor.shape[0] for tensor in indices if tensor is not None]
+ index_shapes = (
+ []
+ ) # [tensor.shape[0] for tensor in indices if tensor is not None]
for _ni, idx_tensor in enumerate(indices):
if idx_tensor is not None:
if idx_tensor.shape[0] != DYNAMIC_DIM:
index_shapes.append(idx_tensor.shape[0])
else:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-03-23 18:18:24.152947+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-03-23 18:18:45.151011+00:00
@@ -321,31 +321,45 @@
# duplicate positions correctly (mirrors test_index_put_accumulate_duplicate_indices).
param(
test_name="1d_duplicate_indices_accumulate",
source_tensor=torch.zeros([6], dtype=torch.float32),
indices_tensor=(torch.tensor([0, 0, 2, 2, 2], dtype=torch.int64),),
- value_tensor=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32),
+ value_tensor=torch.tensor(
+ [1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32
+ ),
accumulate=True,
),
param(
test_name="2d_indices_accumulate_True",
source_tensor=torch.zeros([5, 5], dtype=torch.float32),
- indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
+ indices_tensor=(
+ torch.tensor([0, 0], dtype=torch.int32),
+ torch.tensor([1, 1], dtype=torch.int32),
+ ),
value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32),
accumulate=True,
),
param(
test_name="3d_indices_accumulate_True",
source_tensor=torch.zeros([3, 3, 3], dtype=torch.float32),
- indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([2, 2], dtype=torch.int32)),
+ indices_tensor=(
+ torch.tensor([0, 0], dtype=torch.int32),
+ torch.tensor([1, 1], dtype=torch.int32),
+ torch.tensor([2, 2], dtype=torch.int32),
+ ),
value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32),
accumulate=True,
),
param(
test_name="4d_indices_accumulate_True",
source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.float32),
- indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
+ indices_tensor=(
+ torch.tensor([0, 0], dtype=torch.int32),
+ torch.tensor([1, 1], dtype=torch.int32),
+ torch.tensor([0, 0], dtype=torch.int32),
+ torch.tensor([1, 1], dtype=torch.int32),
+ ),
value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32),
accumulate=True,
),
# Negative indices with accumulate (mirrors test_index_put_accumulate_large_tensor).
param(
@@ -357,29 +371,38 @@
),
# bfloat16 + duplicate indices: computation stays in bfloat16 (no forced fp32 cast).
param(
test_name="accumulate_bfloat16_duplicate",
source_tensor=torch.zeros([4, 4], dtype=torch.bfloat16),
- indices_tensor=(torch.tensor([0, 0, 2], dtype=torch.int64), torch.tensor([1, 1, 3], dtype=torch.int64)),
+ indices_tensor=(
+ torch.tensor([0, 0, 2], dtype=torch.int64),
+ torch.tensor([1, 1, 3], dtype=torch.int64),
+ ),
value_tensor=torch.tensor([1.0, 2.0, 4.0], dtype=torch.bfloat16),
accumulate=True,
),
# float16 + duplicate indices.
param(
test_name="accumulate_float16_duplicate",
source_tensor=torch.zeros([4, 4], dtype=torch.float16),
- indices_tensor=(torch.tensor([1, 1, 3], dtype=torch.int64), torch.tensor([0, 0, 2], dtype=torch.int64)),
+ indices_tensor=(
+ torch.tensor([1, 1, 3], dtype=torch.int64),
+ torch.tensor([0, 0, 2], dtype=torch.int64),
+ ),
value_tensor=torch.tensor([2.0, 3.0, 5.0], dtype=torch.float16),
accumulate=True,
),
# Partial broadcast: one index covers a single position on dim-1 while
# dim-0 has multiple positions — mirrors test_index_put_accumulate_expanded_values
# (t[tensor([0,1,2,3]), tensor([1])] += 1.0).
param(
test_name="accumulate_partial_dim1_broadcast",
source_tensor=torch.zeros([5, 2], dtype=torch.float32),
- indices_tensor=(torch.tensor([0, 1, 2, 3], dtype=torch.int64), torch.tensor([1], dtype=torch.int64)),
+ indices_tensor=(
+ torch.tensor([0, 1, 2, 3], dtype=torch.int64),
+ torch.tensor([1], dtype=torch.int64),
+ ),
value_tensor=torch.tensor([1.0], dtype=torch.float32),
accumulate=True,
),
]
)
@@ -512,12 +535,16 @@
)
trt_mod = torchtrt.dynamo.compile(
ep,
arg_inputs=[
torchtrt.Input(shape=(16,), dtype=torch.float32),
- torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32),
- torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32),
+ torchtrt.Input(
+ min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32
+ ),
+ torchtrt.Input(
+ min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32
+ ),
],
min_block_size=1,
)
result = trt_mod(src.clone(), values, idx)
assert torch.allclose(
@@ -587,11 +614,10 @@
result = trt_mod(cache.clone(), values, idx)
assert torch.allclose(
result, torch_output, atol=1e-3, rtol=1e-3
), f"KV-cache index_put mismatch: max diff = {(result - torch_output).abs().max()}"
-
def test_accumulate_random_walk_duplicate_indices(self):
"""accumulate=True on 1-D input where indices are generated by a random walk
(many duplicates interleaved). Mirrors PyTorch's
test_index_put_accumulate_duplicate_indices, scaled to a TRT-friendly size.
@@ -664,13 +690,13 @@
torchtrt.Input(shape=(1, 2), dtype=torch.int64),
],
min_block_size=1,
)
result = trt_mod(src.clone(), values, i0, i1, i2)
- assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
- f"3D expand accumulate mismatch: max diff = {(result - torch_output).abs().max()}"
- )
+ assert torch.allclose(
+ result, torch_output, atol=1e-4, rtol=1e-4
+ ), f"3D expand accumulate mismatch: max diff = {(result - torch_output).abs().max()}"
def test_empty_index_no_op(self):
"""index_put with an empty index tensor is a no-op — output equals input.
Mirrors PyTorch's test_empty_index: x[empty_idx] = values leaves x unchanged.
@@ -697,13 +723,13 @@
torchtrt.Input(shape=(0,), dtype=torch.int64),
],
min_block_size=1,
)
result = trt_mod(src.clone(), values, idx)
- assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
- f"Empty-index no-op mismatch: {result} vs {torch_output}"
- )
+ assert torch.allclose(
+ result, torch_output, atol=1e-4, rtol=1e-4
+ ), f"Empty-index no-op mismatch: {result} vs {torch_output}"
def test_index_ind_dtype_int_vs_long(self):
"""int32 and int64 index tensors must produce identical results.
Mirrors PyTorch's test_index_ind_dtype.
@@ -725,27 +751,41 @@
assert torch.allclose(ref_long, ref_int), "CPU int32 vs int64 mismatch"
ep_long = torch.export.export(model, args=(src, values, idx_long))
ep_int = torch.export.export(model, args=(src, values, idx_int))
- trt_long = torchtrt.dynamo.compile(ep_long, arg_inputs=[
- torchtrt.Input(shape=(4, 4), dtype=torch.float32),
- torchtrt.Input(shape=(4,), dtype=torch.float32),
- torchtrt.Input(shape=(4,), dtype=torch.int64),
- ], min_block_size=1)
- trt_int = torchtrt.dynamo.compile(ep_int, arg_inputs=[
- torchtrt.Input(shape=(4, 4), dtype=torch.float32),
- torchtrt.Input(shape=(4,), dtype=torch.float32),
- torchtrt.Input(shape=(4,), dtype=torch.int32),
- ], min_block_size=1)
+ trt_long = torchtrt.dynamo.compile(
+ ep_long,
+ arg_inputs=[
+ torchtrt.Input(shape=(4, 4), dtype=torch.float32),
+ torchtrt.Input(shape=(4,), dtype=torch.float32),
+ torchtrt.Input(shape=(4,), dtype=torch.int64),
+ ],
+ min_block_size=1,
+ )
+ trt_int = torchtrt.dynamo.compile(
+ ep_int,
+ arg_inputs=[
+ torchtrt.Input(shape=(4, 4), dtype=torch.float32),
+ torchtrt.Input(shape=(4,), dtype=torch.float32),
+ torchtrt.Input(shape=(4,), dtype=torch.int32),
+ ],
+ min_block_size=1,
+ )
out_long = trt_long(src.clone(), values, idx_long)
out_int = trt_int(src.clone(), values, idx_int)
- assert torch.allclose(out_long, ref_long, atol=1e-4, rtol=1e-4), "TRT int64 mismatch"
- assert torch.allclose(out_int, ref_int, atol=1e-4, rtol=1e-4), "TRT int32 mismatch"
- assert torch.allclose(out_long, out_int, atol=1e-4, rtol=1e-4), "TRT int32 vs int64 inconsistency"
+ assert torch.allclose(
+ out_long, ref_long, atol=1e-4, rtol=1e-4
+ ), "TRT int64 mismatch"
+ assert torch.allclose(
+ out_int, ref_int, atol=1e-4, rtol=1e-4
+ ), "TRT int32 mismatch"
+ assert torch.allclose(
+ out_long, out_int, atol=1e-4, rtol=1e-4
+ ), "TRT int32 vs int64 inconsistency"
def test_accumulate_non_contiguous_source(self):
"""accumulate=True on a non-contiguous (sliced) source tensor.
Mirrors PyTorch's test_index_put_accumulate_non_contiguous.
@@ -780,26 +820,24 @@
torchtrt.Input(shape=(2,), dtype=torch.int64),
],
min_block_size=1,
)
result = trt_mod(src_slice.contiguous(), values, idx)
- assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
- f"Non-contiguous accumulate mismatch: max diff = {(result - torch_output).abs().max()}"
- )
+ assert torch.allclose(
+ result, torch_output, atol=1e-4, rtol=1e-4
+ ), f"Non-contiguous accumulate mismatch: max diff = {(result - torch_output).abs().max()}"
def test_accumulate_expanded_values_broadcast(self):
"""accumulate=True with value broadcasting — 0D scalar and 1D values
broadcast across unique indexed positions.
Mirrors PyTorch's test_index_put_accumulate_expanded_values (unique indices only).
"""
class AccumBroadcast(torch.nn.Module):
def forward(self, src, values, idx0, idx1):
- return torch.ops.aten.index_put.default(
- src, [idx0, idx1], values, True
- )
+ return torch.ops.aten.index_put.default(src, [idx0, idx1], values, True)
src = torch.zeros(5, 2, dtype=torch.float32, device="cuda")
idx0 = torch.tensor([0, 1, 2, 3], dtype=torch.int64, device="cuda")
idx1 = torch.tensor([0, 1, 0, 1], dtype=torch.int64, device="cuda")
values_1d = torch.tensor([1.0], dtype=torch.float32, device="cuda")
@@ -817,12 +855,12 @@
torchtrt.Input(shape=(4,), dtype=torch.int64),
],
min_block_size=1,
)
result = trt_mod(src.clone(), values_1d, idx0, idx1)
- assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
- f"Accumulate broadcast mismatch: max diff = {(result - torch_output).abs().max()}"
- )
+ assert torch.allclose(
+ result, torch_output, atol=1e-4, rtol=1e-4
+ ), f"Accumulate broadcast mismatch: max diff = {(result - torch_output).abs().max()}"
if __name__ == "__main__":
run_tests()
8eb87a5 to
3402e30
Compare
3402e30 to
24f99f4
Compare
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2026-03-24 15:40:02.378996+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2026-03-24 15:40:21.128137+00:00
@@ -873,11 +873,13 @@
F = [i for i in range(rank) if indices[i] is None] # Free dimensions
I = [i for i in range(rank) if indices[i] is not None] # Indexed dimensions
K = len(I)
# Determine the maximum size 'N' among the index tensors
if K > 0:
- index_shapes = [] # [tensor.shape[0] for tensor in indices if tensor is not None]
+ index_shapes = (
+ []
+ ) # [tensor.shape[0] for tensor in indices if tensor is not None]
for _ni, idx_tensor in enumerate(indices):
if idx_tensor is not None:
if idx_tensor.shape[0] != DYNAMIC_DIM:
index_shapes.append(idx_tensor.shape[0])
else:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py 2026-03-24 15:40:02.394519+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py 2026-03-24 15:40:22.366634+00:00
@@ -158,11 +158,13 @@
out = torch.ops.aten.index.Tensor(x, indices)
return out
input = torch.randn(2, 2)
index0 = torch.tensor([True, False])
- self.run_test(TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True)
+ self.run_test(
+ TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True
+ )
def test_index_zero_index_three_dim_ITensor(self):
class TestModule(nn.Module):
def forward(self, x, index0):
indices = [None, index0, None]
@@ -170,11 +172,13 @@
return out
input = torch.randn(2, 2, 2)
index0 = torch.randint(0, 1, (1, 1))
index0 = index0.to(torch.int32)
- self.run_test(TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True)
+ self.run_test(
+ TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True
+ )
@unittest.skipIf(
ENABLED_FEATURES.tensorrt_rtx,
"Skipped on tensorrt_rtx due to nonzero not supported",
)
@@ -185,11 +189,13 @@
out = torch.ops.aten.index.Tensor(x, indices)
return out
input = torch.randn(2, 2, 2)
index0 = torch.tensor([True, False])
- self.run_test(TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True)
+ self.run_test(
+ TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True
+ )
class TestIndexDynamicConstantConverter(DispatchTestCase):
@parameterized.expand(
[
24f99f4 to
503ca51
Compare
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_embed_engines.py 2026-03-24 18:29:04.478330+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_embed_engines.py 2026-03-24 18:29:24.473246+00:00
@@ -14,13 +14,11 @@
except (ImportError, RuntimeError):
HAS_TORCHVISION = False
class TestModelToEngineToModel(unittest.TestCase):
- @unittest.skipIf(
- not HAS_TORCHVISION, "torchvision not available"
- )
+ @unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
@unittest.skipIf(
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
)
def test_resnet50(self):
--- /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_module_fallback.py 2026-03-24 18:29:04.478330+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_module_fallback.py 2026-03-24 18:29:24.488458+00:00
@@ -14,13 +14,11 @@
@unittest.skipIf(
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
)
-@unittest.skipIf(
- not HAS_TORCHVISION, "torchvision not available"
-)
+@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
class TestModuleFallback(unittest.TestCase):
def test_fallback_resnet18(self):
self.model = models.resnet18(pretrained=True).eval().to("cuda")
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
compile_spec = {
--- /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_operator_fallback.py 2026-03-24 18:29:04.478330+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_operator_fallback.py 2026-03-24 18:29:24.518305+00:00
@@ -14,13 +14,11 @@
@unittest.skipIf(
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
)
-@unittest.skipIf(
- not HAS_TORCHVISION, "torchvision not available"
-)
+@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
class TestFallbackModels(unittest.TestCase):
def test_fallback_resnet18(self):
self.model = models.resnet18(pretrained=True).eval().to("cuda")
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
compile_spec = {
--- /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_e2e_behavior.py 2026-03-24 18:29:04.478330+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_e2e_behavior.py 2026-03-24 18:29:24.569646+00:00
@@ -16,13 +16,11 @@
@unittest.skipIf(
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
)
-@unittest.skipIf(
- not HAS_TORCHVISION, "torchvision not available"
-)
+@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
class TestInputTypeDefaultsFP32Model(unittest.TestCase):
def test_input_use_default_fp32(self):
self.model = models.resnet18(pretrained=True).eval().to("cuda")
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
@@ -68,13 +66,11 @@
class TestInputTypeDefaultsFP16Model(unittest.TestCase):
@unittest.skipIf(
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
)
- @unittest.skipIf(
- not HAS_TORCHVISION, "torchvision not available"
- )
+ @unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
def test_input_use_default_fp16(self):
self.model = models.resnet18(pretrained=True).eval().to("cuda")
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
half_mod = torch.jit.script(self.model)
@@ -89,13 +85,11 @@
@unittest.skipIf(
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
)
- @unittest.skipIf(
- not HAS_TORCHVISION, "torchvision not available"
- )
+ @unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
def test_input_use_default_fp16_without_fp16_enabled(self):
self.model = models.resnet18(pretrained=True).eval().to("cuda")
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
half_mod = torch.jit.script(self.model)
@@ -108,13 +102,11 @@
@unittest.skipIf(
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
)
- @unittest.skipIf(
- not HAS_TORCHVISION, "torchvision not available"
- )
+ @unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
def test_input_respect_user_setting_fp16_weights_fp32_in(self):
self.model = models.resnet18(pretrained=True).eval().to("cuda")
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
half_mod = torch.jit.script(self.model)
@@ -130,13 +122,11 @@
@unittest.skipIf(
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
)
- @unittest.skipIf(
- not HAS_TORCHVISION, "torchvision not available"
- )
+ @unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor(self):
self.model = models.resnet18(pretrained=True).eval().to("cuda")
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
half_mod = torch.jit.script(self.model)
--- /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_ts_backend.py 2026-03-24 18:29:04.478485+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_ts_backend.py 2026-03-24 18:29:24.694311+00:00
@@ -12,13 +12,11 @@
HAS_TORCHVISION = True
except (ImportError, RuntimeError):
HAS_TORCHVISION = False
-@unittest.skipIf(
- not HAS_TORCHVISION, "torchvision not available"
-)
+@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
class TestCompile(unittest.TestCase):
def test_compile_traced(self):
self.model = models.vgg16(pretrained=True).eval().to("cuda")
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
self.traced_model = torch.jit.trace(self.model, [self.input])
@@ -129,13 +127,11 @@
@unittest.skipIf(
torchtrt.ENABLED_FEATURES.tensorrt_rtx,
"aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
)
-@unittest.skipIf(
- not HAS_TORCHVISION, "torchvision not available"
-)
+@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
class TestCheckMethodOpSupport(unittest.TestCase):
def test_check_support(self):
module = models.alexnet(pretrained=True).eval().to("cuda")
self.module = torch.jit.trace(module, torch.ones((1, 3, 224, 224)).to("cuda"))
4352d78 to
fb2f692
Compare
fb2f692 to
314ec6c
Compare
zewenli98
left a comment
There was a problem hiding this comment.
The changes look good to me. My only concern for _index_put_scatter_add is the perf. Is it faster than leaving it in pytorch?
| # reduction mode. The current implementation (scatter into zeros + elementwise | ||
| # add) only gives correct results when every scattered index is unique. |
There was a problem hiding this comment.
The new implementation should have supported the cases with duplicated indices right?
There was a problem hiding this comment.
I believe so, will add a test for it
In rough tests it seemed competitive, but I didn't have a good model on hand that uses it. If you have a reference model, I can run it. |
There was a problem hiding this comment.
There are some changes that do not conform to C++ style guidelines:
diff --git a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.cpp b/tmp/changes.txt
index fbf81a9..e7a3620 100644
--- a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.cpp
+++ b/tmp/changes.txt
@@ -15,7 +15,6 @@ namespace impl {
// Helpers
// ---------------------------------------------------------------------------
-
// ---------------------------------------------------------------------------
// ScatterAddPlugin — construction
// ---------------------------------------------------------------------------
@@ -26,8 +25,7 @@ ScatterAddPlugin::ScatterAddPlugin() = default;
// IPluginV3
// ---------------------------------------------------------------------------
-nvinfer1::IPluginCapability* ScatterAddPlugin::getCapabilityInterface(
- nvinfer1::PluginCapabilityType type) noexcept {
+nvinfer1::IPluginCapability* ScatterAddPlugin::getCapabilityInterface(nvinfer1::PluginCapabilityType type) noexcept {
switch (type) {
case nvinfer1::PluginCapabilityType::kCORE:
return static_cast<nvinfer1::IPluginV3OneCore*>(this);
@@ -105,15 +103,13 @@ bool ScatterAddPlugin::supportsFormatCombination(
// Positions 1 through nbInputs-2 are index tensors: int32 or int64.
if (pos >= 1 && pos <= nbInputs - 2) {
- return desc.desc.type == nvinfer1::DataType::kINT32 ||
- desc.desc.type == nvinfer1::DataType::kINT64;
+ return desc.desc.type == nvinfer1::DataType::kINT32 || desc.desc.type == nvinfer1::DataType::kINT64;
}
// pos 0 (src), pos nbInputs-1 (values), pos nbInputs (output):
// float32 / float16 / bfloat16, all sharing the same type.
- const bool float_type = desc.desc.type == nvinfer1::DataType::kFLOAT ||
- desc.desc.type == nvinfer1::DataType::kHALF ||
- desc.desc.type == nvinfer1::DataType::kBF16;
+ const bool float_type = desc.desc.type == nvinfer1::DataType::kFLOAT || desc.desc.type == nvinfer1::DataType::kHALF ||
+ desc.desc.type == nvinfer1::DataType::kBF16;
if (!float_type) {
return false;
}
@@ -131,7 +127,7 @@ int32_t ScatterAddPlugin::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* /*out*/,
int32_t /*nbOutputs*/) noexcept {
dtype_ = in[0].desc.type;
- n_indices_ = nbInputs - 2; // exclude src and values
+ n_indices_ = nbInputs - 2; // exclude src and values
idx_dtypes_.resize(n_indices_);
for (int i = 0; i < n_indices_; ++i) {
idx_dtypes_[i] = in[1 + i].desc.type;
@@ -180,8 +176,7 @@ int32_t ScatterAddPlugin::enqueue(
const auto float_opts = at::TensorOptions().device(at::kCUDA).dtype(float_dtype);
at::Tensor src = at::from_blob(const_cast<void*>(inputs[0]), src_shape_, float_opts);
- at::Tensor val = at::from_blob(
- const_cast<void*>(inputs[n_indices_ + 1]), val_shape_, float_opts);
+ at::Tensor val = at::from_blob(const_cast<void*>(inputs[n_indices_ + 1]), val_shape_, float_opts);
// Build the indices list — one entry per index tensor, all cast to int64
// as required by ATen's index_put kernel.
@@ -190,8 +185,7 @@ int32_t ScatterAddPlugin::enqueue(
for (int i = 0; i < n_indices_; ++i) {
const at::ScalarType int_dtype = util::TRTDataTypeToScalarType(idx_dtypes_[i]);
const auto int_opts = at::TensorOptions().device(at::kCUDA).dtype(int_dtype);
- at::Tensor idx = at::from_blob(
- const_cast<void*>(inputs[1 + i]), idx_shapes_[i], int_opts);
+ at::Tensor idx = at::from_blob(const_cast<void*>(inputs[1 + i]), idx_shapes_[i], int_opts);
indices.push_back(std::optional<at::Tensor>(idx.to(torch::kLong)));
}
@@ -222,8 +216,7 @@ int32_t ScatterAddPlugin::enqueue(
return 0;
}
-nvinfer1::IPluginV3* ScatterAddPlugin::attachToContext(
- nvinfer1::IPluginResourceContext* /*context*/) noexcept {
+nvinfer1::IPluginV3* ScatterAddPlugin::attachToContext(nvinfer1::IPluginResourceContext* /*context*/) noexcept {
return clone();
}
diff --git a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.h b/tmp/changes.txt
index 0585f22..f882874 100644
--- a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.h
+++ b/tmp/changes.txt
@@ -110,15 +110,14 @@ class ScatterAddPlugin : public nvinfer1::IPluginV3,
// Captured at configurePlugin / onShapeChange time.
// Layout: input[0]=src, input[1..n_indices_]=indices, input[n_indices_+1]=values
nvinfer1::DataType dtype_{nvinfer1::DataType::kFLOAT};
- std::vector<nvinfer1::DataType> idx_dtypes_; // one per index input
+ std::vector<nvinfer1::DataType> idx_dtypes_; // one per index input
std::vector<int64_t> src_shape_;
std::vector<int64_t> val_shape_;
- std::vector<std::vector<int64_t>> idx_shapes_; // full shape of each index tensor
- int32_t n_indices_{0}; // number of index tensors
+ std::vector<std::vector<int64_t>> idx_shapes_; // full shape of each index tensor
+ int32_t n_indices_{0}; // number of index tensors
// Empty field collection — this plugin has no serializable attributes.
nvinfer1::PluginFieldCollection empty_fc_{0, nullptr};
-
};
// ---------------------------------------------------------------------------
ERROR: Some files do not conform to style guidelines
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2026-04-02 22:43:56.836073+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2026-04-02 22:44:20.017903+00:00
@@ -588,11 +588,10 @@
set_layer_name(gather_layer, target, name + "_gather_layer_element", source_ir)
out = gather_layer.get_output(0)
return out
-
def index_put_scatter_add_plugin(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
@@ -608,13 +607,13 @@
Supports any number of non-None index tensors (N >= 1). The plugin
inputs are laid out as: [src, idx_0, ..., idx_{N-1}, values].
"""
non_none_indices = [x for x in input_indices if x is not None]
- assert len(non_none_indices) >= 1, (
- "ScatterAdd plugin requires at least one non-None index tensor"
- )
+ assert (
+ len(non_none_indices) >= 1
+ ), "ScatterAdd plugin requires at least one non-None index tensor"
# Plugin supports float32/float16/bfloat16; cast other types through float32.
_supported_float_dtypes = (trt.float32, trt.float16, trt.bfloat16)
original_dtype = input_tensor.dtype
if original_dtype not in _supported_float_dtypes:
@@ -635,18 +634,16 @@
if not isinstance(values, ITensor):
values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=1)
# Values must match src dtype after any cast above.
if values.dtype != input_tensor.dtype:
- values = cast_trt_tensor(
- ctx, values, input_tensor.dtype, f"{name}_values_cast"
- )
+ values = cast_trt_tensor(ctx, values, input_tensor.dtype, f"{name}_values_cast")
creator = trt.get_plugin_registry().get_creator("ScatterAdd", "1", "torch_tensorrt")
- assert creator is not None, (
- "ScatterAdd plugin creator not found — is torch_tensorrt_runtime loaded?"
- )
+ assert (
+ creator is not None
+ ), "ScatterAdd plugin creator not found — is torch_tensorrt_runtime loaded?"
pfc = trt.PluginFieldCollection([])
plugin = creator.create_plugin("ScatterAdd", pfc, trt.TensorRTPhase.BUILD)
assert plugin is not None, "Failed to create ScatterAdd plugin instance"
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2026-04-02 22:43:56.833500+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2026-04-02 22:44:21.995858+00:00
@@ -1110,11 +1110,12 @@
ScatterAdd plugin is present in the TRT plugin registry (registered
at library init time when torch_tensorrt_runtime is loaded).
Supports any number of non-None index tensors.
"""
if not (
- index_dtype_validator(node, settings) and index_nonbool_validator(node, settings)
+ index_dtype_validator(node, settings)
+ and index_nonbool_validator(node, settings)
):
return False
if not args_bounds_check(node.args, 3, False):
return False
return _scatter_add_plugin_available()
@@ -1176,12 +1177,10 @@
name,
args[0],
args[1],
args[2],
)
-
-
@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-04-02 22:43:56.864073+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-04-02 22:44:23.373273+00:00
@@ -859,11 +859,10 @@
result = trt_mod(src.clone(), values_1d, idx0, idx1)
assert torch.allclose(
result, torch_output, atol=1e-4, rtol=1e-4
), f"Accumulate broadcast mismatch: max diff = {(result - torch_output).abs().max()}"
-
# ------------------------------------------------------------------
# Duplicate-index tests for realistic use-case models
# These mirror the scenarios in experiments/bench_index_put_scatter_add.py
# and verify that _index_put_scatter_add correctly accumulates into
# duplicate positions when index_put is embedded in a larger graph.
@@ -970,11 +969,11 @@
self.conv = torch.nn.Conv1d(C, 8, kernel_size=3, padding=1)
self.head = torch.nn.Linear(8, 4, bias=False)
def forward(self, signal, hist):
# signal: (1, C, L), hist: (n_bins, 8)
- feat = self.conv(signal).squeeze(0).T # (L, 8)
+ feat = self.conv(signal).squeeze(0).T # (L, 8)
hist = hist.index_put((get_bin_ids(),), feat, accumulate=True)
return self.head(hist.mean(dim=0))
signal = torch.randn(1, C, L)
hist = torch.zeros(n_bins, 8)
c6e3a52 to
ac26d06
Compare
There was a problem hiding this comment.
There are some changes that do not conform to C++ style guidelines:
diff --git a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.cpp b/tmp/changes.txt
index 038667c..ac7c321 100644
--- a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.cpp
+++ b/tmp/changes.txt
@@ -11,11 +11,9 @@ namespace core {
namespace plugins {
namespace impl {
-
ScatterAddPlugin::ScatterAddPlugin() = default;
-nvinfer1::IPluginCapability* ScatterAddPlugin::getCapabilityInterface(
- nvinfer1::PluginCapabilityType type) noexcept {
+nvinfer1::IPluginCapability* ScatterAddPlugin::getCapabilityInterface(nvinfer1::PluginCapabilityType type) noexcept {
switch (type) {
case nvinfer1::PluginCapabilityType::kCORE:
return static_cast<nvinfer1::IPluginV3OneCore*>(this);
@@ -93,15 +91,13 @@ bool ScatterAddPlugin::supportsFormatCombination(
// Positions 1 through nbInputs-2 are index tensors: int32 or int64.
if (pos >= 1 && pos <= nbInputs - 2) {
- return desc.desc.type == nvinfer1::DataType::kINT32 ||
- desc.desc.type == nvinfer1::DataType::kINT64;
+ return desc.desc.type == nvinfer1::DataType::kINT32 || desc.desc.type == nvinfer1::DataType::kINT64;
}
// pos 0 (src), pos nbInputs-1 (values), pos nbInputs (output):
// float32 / float16 / bfloat16, all sharing the same type.
- const bool float_type = desc.desc.type == nvinfer1::DataType::kFLOAT ||
- desc.desc.type == nvinfer1::DataType::kHALF ||
- desc.desc.type == nvinfer1::DataType::kBF16;
+ const bool float_type = desc.desc.type == nvinfer1::DataType::kFLOAT || desc.desc.type == nvinfer1::DataType::kHALF ||
+ desc.desc.type == nvinfer1::DataType::kBF16;
if (!float_type) {
return false;
}
@@ -119,7 +115,7 @@ int32_t ScatterAddPlugin::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* /*out*/,
int32_t /*nbOutputs*/) noexcept {
dtype_ = in[0].desc.type;
- n_indices_ = nbInputs - 2; // exclude src and values
+ n_indices_ = nbInputs - 2; // exclude src and values
idx_dtypes_.resize(n_indices_);
for (int i = 0; i < n_indices_; ++i) {
idx_dtypes_[i] = in[1 + i].desc.type;
@@ -168,8 +164,7 @@ int32_t ScatterAddPlugin::enqueue(
const auto float_opts = at::TensorOptions().device(at::kCUDA).dtype(float_dtype);
at::Tensor src = at::from_blob(const_cast<void*>(inputs[0]), src_shape_, float_opts);
- at::Tensor val = at::from_blob(
- const_cast<void*>(inputs[n_indices_ + 1]), val_shape_, float_opts);
+ at::Tensor val = at::from_blob(const_cast<void*>(inputs[n_indices_ + 1]), val_shape_, float_opts);
// Build the indices list — one entry per index tensor, all cast to int64
// as required by ATen's index_put kernel.
@@ -178,8 +173,7 @@ int32_t ScatterAddPlugin::enqueue(
for (int i = 0; i < n_indices_; ++i) {
const at::ScalarType int_dtype = util::TRTDataTypeToScalarType(idx_dtypes_[i]);
const auto int_opts = at::TensorOptions().device(at::kCUDA).dtype(int_dtype);
- at::Tensor idx = at::from_blob(
- const_cast<void*>(inputs[1 + i]), idx_shapes_[i], int_opts);
+ at::Tensor idx = at::from_blob(const_cast<void*>(inputs[1 + i]), idx_shapes_[i], int_opts);
indices.push_back(std::optional<at::Tensor>(idx.to(torch::kLong)));
}
@@ -210,8 +204,7 @@ int32_t ScatterAddPlugin::enqueue(
return 0;
}
-nvinfer1::IPluginV3* ScatterAddPlugin::attachToContext(
- nvinfer1::IPluginResourceContext* /*context*/) noexcept {
+nvinfer1::IPluginV3* ScatterAddPlugin::attachToContext(nvinfer1::IPluginResourceContext* /*context*/) noexcept {
return clone();
}
diff --git a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.h b/tmp/changes.txt
index 0585f22..f882874 100644
--- a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.h
+++ b/tmp/changes.txt
@@ -110,15 +110,14 @@ class ScatterAddPlugin : public nvinfer1::IPluginV3,
// Captured at configurePlugin / onShapeChange time.
// Layout: input[0]=src, input[1..n_indices_]=indices, input[n_indices_+1]=values
nvinfer1::DataType dtype_{nvinfer1::DataType::kFLOAT};
- std::vector<nvinfer1::DataType> idx_dtypes_; // one per index input
+ std::vector<nvinfer1::DataType> idx_dtypes_; // one per index input
std::vector<int64_t> src_shape_;
std::vector<int64_t> val_shape_;
- std::vector<std::vector<int64_t>> idx_shapes_; // full shape of each index tensor
- int32_t n_indices_{0}; // number of index tensors
+ std::vector<std::vector<int64_t>> idx_shapes_; // full shape of each index tensor
+ int32_t n_indices_{0}; // number of index tensors
// Empty field collection — this plugin has no serializable attributes.
nvinfer1::PluginFieldCollection empty_fc_{0, nullptr};
-
};
// ---------------------------------------------------------------------------
ERROR: Some files do not conform to style guidelines
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2026-04-02 22:44:59.718079+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py 2026-04-02 22:45:19.360198+00:00
@@ -588,11 +588,10 @@
set_layer_name(gather_layer, target, name + "_gather_layer_element", source_ir)
out = gather_layer.get_output(0)
return out
-
def index_put_scatter_add_plugin(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
@@ -608,13 +607,13 @@
Supports any number of non-None index tensors (N >= 1). The plugin
inputs are laid out as: [src, idx_0, ..., idx_{N-1}, values].
"""
non_none_indices = [x for x in input_indices if x is not None]
- assert len(non_none_indices) >= 1, (
- "ScatterAdd plugin requires at least one non-None index tensor"
- )
+ assert (
+ len(non_none_indices) >= 1
+ ), "ScatterAdd plugin requires at least one non-None index tensor"
# Plugin supports float32/float16/bfloat16; cast other types through float32.
_supported_float_dtypes = (trt.float32, trt.float16, trt.bfloat16)
original_dtype = input_tensor.dtype
if original_dtype not in _supported_float_dtypes:
@@ -635,18 +634,16 @@
if not isinstance(values, ITensor):
values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=1)
# Values must match src dtype after any cast above.
if values.dtype != input_tensor.dtype:
- values = cast_trt_tensor(
- ctx, values, input_tensor.dtype, f"{name}_values_cast"
- )
+ values = cast_trt_tensor(ctx, values, input_tensor.dtype, f"{name}_values_cast")
creator = trt.get_plugin_registry().get_creator("ScatterAdd", "1", "torch_tensorrt")
- assert creator is not None, (
- "ScatterAdd plugin creator not found — is torch_tensorrt_runtime loaded?"
- )
+ assert (
+ creator is not None
+ ), "ScatterAdd plugin creator not found — is torch_tensorrt_runtime loaded?"
pfc = trt.PluginFieldCollection([])
plugin = creator.create_plugin("ScatterAdd", pfc, trt.TensorRTPhase.BUILD)
assert plugin is not None, "Failed to create ScatterAdd plugin instance"
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2026-04-02 22:44:59.716077+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 2026-04-02 22:45:20.952750+00:00
@@ -1110,11 +1110,12 @@
ScatterAdd plugin is present in the TRT plugin registry (registered
at library init time when torch_tensorrt_runtime is loaded).
Supports any number of non-None index tensors.
"""
if not (
- index_dtype_validator(node, settings) and index_nonbool_validator(node, settings)
+ index_dtype_validator(node, settings)
+ and index_nonbool_validator(node, settings)
):
return False
if not args_bounds_check(node.args, 3, False):
return False
return _scatter_add_plugin_available()
@@ -1176,12 +1177,10 @@
name,
args[0],
args[1],
args[2],
)
-
-
@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-04-02 22:44:59.743099+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-04-02 22:45:22.064973+00:00
@@ -859,11 +859,10 @@
result = trt_mod(src.clone(), values_1d, idx0, idx1)
assert torch.allclose(
result, torch_output, atol=1e-4, rtol=1e-4
), f"Accumulate broadcast mismatch: max diff = {(result - torch_output).abs().max()}"
-
# ------------------------------------------------------------------
# Duplicate-index tests for realistic use-case models
# These mirror the scenarios in experiments/bench_index_put_scatter_add.py
# and verify that _index_put_scatter_add correctly accumulates into
# duplicate positions when index_put is embedded in a larger graph.
@@ -970,11 +969,11 @@
self.conv = torch.nn.Conv1d(C, 8, kernel_size=3, padding=1)
self.head = torch.nn.Linear(8, 4, bias=False)
def forward(self, signal, hist):
# signal: (1, C, L), hist: (n_bins, 8)
- feat = self.conv(signal).squeeze(0).T # (L, 8)
+ feat = self.conv(signal).squeeze(0).T # (L, 8)
hist = hist.index_put((get_bin_ids(),), feat, accumulate=True)
return self.head(hist.mean(dim=0))
signal = torch.randn(1, C, L)
hist = torch.zeros(n_bins, 8)
ac26d06 to
65a6c46
Compare
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-04-03 21:29:49.036765+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-04-03 21:30:12.049600+00:00
@@ -873,11 +873,12 @@
# and verify that _index_put_scatter_add correctly accumulates into
# duplicate positions when index_put is embedded in a larger graph.
# ------------------------------------------------------------------
@pytest.mark.skipif(
- ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+ ENABLED_FEATURES.tensorrt_rtx,
+ reason="ScatterAdd plugin not available in TRT RTX",
)
def test_kv_cache_duplicate_slot_writes(self):
"""KV-cache style: linear projection → index_put(accumulate=True) into
a flat cache with duplicate slot indices → output projection.
@@ -914,11 +915,12 @@
enable_passes=True,
use_explicit_typing=True,
)
@pytest.mark.skipif(
- ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+ ENABLED_FEATURES.tensorrt_rtx,
+ reason="ScatterAdd plugin not available in TRT RTX",
)
def test_sparse_embedding_duplicate_seq_ids(self):
"""Sparse embedding accumulation: embedding lookup → index_put(accumulate=True)
into per-sequence accumulators where many tokens map to the same sequence → ReLU.
@@ -958,11 +960,12 @@
enable_passes=True,
use_explicit_typing=True,
)
@pytest.mark.skipif(
- ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+ ENABLED_FEATURES.tensorrt_rtx,
+ reason="ScatterAdd plugin not available in TRT RTX",
)
def test_histogram_conv_duplicate_bin_ids(self):
"""Histogram accumulation: Conv1d → index_put(accumulate=True) into histogram
bins where many frames land in the same bin → mean pool → linear.
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-04-03 22:43:11.354007+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-04-03 22:43:32.026211+00:00
@@ -873,11 +873,12 @@
# and verify that _index_put_scatter_add correctly accumulates into
# duplicate positions when index_put is embedded in a larger graph.
# ------------------------------------------------------------------
@pytest.mark.skipif(
- ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+ ENABLED_FEATURES.tensorrt_rtx,
+ reason="ScatterAdd plugin not available in TRT RTX",
)
def test_kv_cache_duplicate_slot_writes(self):
"""KV-cache style: linear projection → index_put(accumulate=True) into
a flat cache with duplicate slot indices → output projection.
@@ -914,11 +915,12 @@
enable_passes=True,
use_explicit_typing=True,
)
@pytest.mark.skipif(
- ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+ ENABLED_FEATURES.tensorrt_rtx,
+ reason="ScatterAdd plugin not available in TRT RTX",
)
def test_sparse_embedding_duplicate_seq_ids(self):
"""Sparse embedding accumulation: embedding lookup → index_put(accumulate=True)
into per-sequence accumulators where many tokens map to the same sequence → ReLU.
@@ -958,11 +960,12 @@
enable_passes=True,
use_explicit_typing=True,
)
@pytest.mark.skipif(
- ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+ ENABLED_FEATURES.tensorrt_rtx,
+ reason="ScatterAdd plugin not available in TRT RTX",
)
def test_histogram_conv_duplicate_bin_ids(self):
"""Histogram accumulation: Conv1d → index_put(accumulate=True) into histogram
bins where many frames land in the same bin → mean pool → linear.
25a0c4a to
c57b696
Compare
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-04-03 23:21:15.419550+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py 2026-04-03 23:21:35.379005+00:00
@@ -873,11 +873,12 @@
# and verify that _index_put_scatter_add correctly accumulates into
# duplicate positions when index_put is embedded in a larger graph.
# ------------------------------------------------------------------
@pytest.mark.skipif(
- ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+ ENABLED_FEATURES.tensorrt_rtx,
+ reason="ScatterAdd plugin not available in TRT RTX",
)
def test_kv_cache_duplicate_slot_writes(self):
"""KV-cache style: linear projection → index_put(accumulate=True) into
a flat cache with duplicate slot indices → output projection.
@@ -914,11 +915,12 @@
enable_passes=True,
use_explicit_typing=True,
)
@pytest.mark.skipif(
- ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+ ENABLED_FEATURES.tensorrt_rtx,
+ reason="ScatterAdd plugin not available in TRT RTX",
)
def test_sparse_embedding_duplicate_seq_ids(self):
"""Sparse embedding accumulation: embedding lookup → index_put(accumulate=True)
into per-sequence accumulators where many tokens map to the same sequence → ReLU.
@@ -958,11 +960,12 @@
enable_passes=True,
use_explicit_typing=True,
)
@pytest.mark.skipif(
- ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+ ENABLED_FEATURES.tensorrt_rtx,
+ reason="ScatterAdd plugin not available in TRT RTX",
)
def test_histogram_conv_duplicate_bin_ids(self):
"""Histogram accumulation: Conv1d → index_put(accumulate=True) into histogram
bins where many frames land in the same bin → mean pool → linear.
Description
Index put appears in KV cache implementations from huggingface, we get a broadcast error because there was no validator catching this. This PR adds support for dynamic shape in the op and properly guards failure modes
Fixes #4139
Fixes #4142
Fixes #3647
Fixes #2939
Fixes #3798
Fixes #3806
- `index_put` fails with dynamic non-indexed dimensions ("Dynamic shape in free dimensions not supported") — fixed by resolving dynamic free dimensions via `get_shape`
- `index_copy_` / `StaticCache` fails with a shape broadcast error ("Cannot broadcast (1,8,1,128) to (1,1,8,128)")
- `index_add_` fails with dynamic shape
- `x[bool_3d_mask] = 0.0` crashes with `ValueError: __len__() should return >= 0` — fixed `expand_boolean_indices` to correctly split the TRT `add_non_zero` output `(ndim, N)` into per-dim `(N,)` tensors
- `accumulate=True` with duplicate indices produces wrong results (scatter overwrites instead of accumulating)

Accumulate is supported by a MxP implementation of native TensorRT ops, which we expect to have ~10–15% overhead in reasonably sized problems but can be costly in very large tasks.
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: