
Add dynamic shape support to index_put#4143

Open
narendasan wants to merge 9 commits into main from narendasan/push-trxznozvxnsq

Conversation

@narendasan
Collaborator

@narendasan narendasan commented Mar 20, 2026

Description

index_put appears in KV-cache implementations from Hugging Face, and we were getting a broadcast error because no validator was catching the unsupported case. This PR adds dynamic shape support to the op and properly guards its failure modes.

Fixes #4139
Fixes #4142
Fixes #3647
Fixes #2939
Fixes #3798
Fixes #3806

Issue and fix summary:

  • #4139: index_put fails with dynamic non-indexed dimensions ("Dynamic shape in free dimensions not supported"). Fix: rewrote the converter to propagate dynamic dims through get_shape.
  • #4142: index_copy_ / StaticCache fails with a shape broadcast error ("Cannot broadcast (1,8,1,128) to (1,1,8,128)"). Fix: corrected axis alignment in the ND scatter index construction.
  • #3806: index_add_ fails with dynamic shape. Fix: dynamic M/P is handled in the scatter-add path.
  • #3798: non-consecutive indices combined with dynamic shape were not supported. Fix: the converter now handles arbitrary non-contiguous index combinations with dynamic shapes.
  • #3777: x[bool_3d_mask] = 0.0 crashes with ValueError: __len__() should return >= 0. Fix: expand_boolean_indices now correctly splits the TRT add_non_zero output of shape (ndim, N) into per-dim (N,) tensors.
  • #2939: accumulate=True with duplicate indices produces wrong results (scatter overwrites instead of accumulating). Fix: accumulate is now implemented so that duplicates sum correctly (see below).
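The failing KV-cache pattern can be reproduced in a few lines of plain PyTorch. The shapes below are illustrative only (not taken from the issues): the batch dim plays the role of the dynamic free dimension, while the seq dim is indexed by static cache positions.

```python
import torch

# Hypothetical shapes for the KV-cache update behind #4139/#4142:
# a free (dynamic) batch dim surrounding a statically indexed seq dim.
cache = torch.zeros(1, 2, 8, 4)  # (layers, batch, seq, head_dim)
idx = torch.tensor([3, 4])       # cache positions to overwrite
values = torch.arange(16, dtype=torch.float32).reshape(1, 2, 2, 4)
cache[:, :, idx, :] = values     # lowers to aten.index_put with free dims around dim 2
```

Under torch.export this lowers to aten.index_put.default with None entries for the free dims, which is exactly the case the converter previously rejected.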

Accumulate is supported via an MxP implementation built from native TensorRT ops, which we expect to add roughly 10-15% overhead on reasonably sized problems but can be costly on very large ones.
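As a sketch of the MxP idea in plain PyTorch (not the actual converter code, which builds the equivalent from TensorRT layers): duplicate indices can be summed with a one-hot M x P matrix product instead of a last-write-wins scatter. The function name here is hypothetical.

```python
import torch

def scatter_add_mxp(dst, idx, values):
    # Build a one-hot matrix of shape (M index entries, P destination slots)
    # and reduce duplicates in one matmul: columns with repeated indices sum.
    M, P = idx.shape[0], dst.shape[0]
    one_hot = torch.zeros(M, P, dtype=values.dtype)
    one_hot[torch.arange(M), idx] = 1.0
    return dst + one_hot.t() @ values  # (P, M) @ (M,) -> (P,)

dst = torch.zeros(6)
idx = torch.tensor([0, 0, 2, 2, 2])
vals = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
out = scatter_add_mxp(dst, idx, vals)  # matches index_put_ with accumulate=True
```

The MxP one-hot buffer is where the overhead comes from: it is linear in both the number of index entries and the destination size, which is cheap at moderate sizes but grows quickly for very large tensors.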

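For the boolean-mask path (#3777), a PyTorch analogue of the split, assuming the (ndim, N) non-zero layout produced by TRT's add_non_zero (torch.nonzero itself returns (N, ndim), hence the transpose):

```python
import torch

x = torch.tensor([[0.0, 1.0], [2.0, 0.0]])
mask = x != 0.0
nz = mask.nonzero().t()                        # (ndim, N), like TRT's INonZeroLayer output
per_dim = [nz[i] for i in range(nz.shape[0])]  # ndim index tensors of shape (N,)
x[tuple(per_dim)] = 0.0                        # equivalent to x[mask] = 0.0
```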
Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@meta-cla meta-cla bot added the cla signed label Mar 20, 2026
@narendasan narendasan force-pushed the narendasan/push-trxznozvxnsq branch from fef7530 to 17d88b1 Compare March 20, 2026 20:13
@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Mar 20, 2026

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2026-03-20 20:13:55.127307+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2026-03-20 20:14:14.887149+00:00
@@ -623,11 +623,16 @@
    for _fi, _fdim in enumerate(F):
        _s = input_tensor.shape[_fdim]
        if _s == DYNAMIC_DIM:
            F_shape_values.append(
                get_shape(
-                    ctx, target, source_ir, f"{name}_fshape_{_fdim}", input_tensor, _fdim
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_fshape_{_fdim}",
+                    input_tensor,
+                    _fdim,
                )
            )
        else:
            F_shape_values.append(_s)
    _has_dynamic_f = any(isinstance(_s, TRTTensor) for _s in F_shape_values)
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-03-20 20:13:55.159208+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-03-20 20:14:16.694415+00:00
@@ -357,11 +357,10 @@
        )
        result = trt_engine(source_tensor, indices_tensor, value_tensor)

        torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)

-
    def test_kv_cache_dynamic_batch(self):
        """index_put with a dynamic free dimension (batch) — issue #4139.

        Pattern: cache[..., idx, :] = values  where dim-1 (batch) is dynamic
        and dim-2 (cache/time) is the indexed static dimension.
@@ -420,12 +419,12 @@
            use_explicit_typing=True,
            min_block_size=1,
        )

        result = trt_mod(cache.clone(), values, idx)
-        assert torch.allclose(result, torch_output, atol=1e-3, rtol=1e-3), (
-            f"KV-cache index_put mismatch: max diff = {(result - torch_output).abs().max()}"
-        )
+        assert torch.allclose(
+            result, torch_output, atol=1e-3, rtol=1e-3
+        ), f"KV-cache index_put mismatch: max diff = {(result - torch_output).abs().max()}"


if __name__ == "__main__":
    run_tests()

@narendasan narendasan force-pushed the narendasan/push-trxznozvxnsq branch from 17d88b1 to 858bf17 Compare March 20, 2026 22:24

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-03-20 22:25:06.608217+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-03-20 22:25:28.809312+00:00
@@ -392,12 +392,16 @@
        )
        trt_mod = torchtrt.dynamo.compile(
            ep,
            arg_inputs=[
                torchtrt.Input(shape=(16,), dtype=torch.float32),
-                torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32),
-                torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32),
+                torchtrt.Input(
+                    min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32
+                ),
+                torchtrt.Input(
+                    min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32
+                ),
            ],
            min_block_size=1,
        )
        result = trt_mod(src.clone(), values, idx)
        assert torch.allclose(


@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2026-03-23 18:18:24.127073+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2026-03-23 18:18:42.849840+00:00
@@ -772,13 +772,15 @@
        ctx, target, source_ir, f"{name}_result_flat", src_flat, delta
    )

    # Rebuild the output shape (may contain dynamic dims)
    out_shape = tuple(
-        get_shape(ctx, target, source_ir, f"{name}_oshape_{i}", input_tensor, i)
-        if input_tensor.shape[i] == DYNAMIC_DIM
-        else input_tensor.shape[i]
+        (
+            get_shape(ctx, target, source_ir, f"{name}_oshape_{i}", input_tensor, i)
+            if input_tensor.shape[i] == DYNAMIC_DIM
+            else input_tensor.shape[i]
+        )
        for i in range(rank)
    )
    return impl.shuffle.reshape(
        ctx, target, source_ir, f"{name}_result", result_flat, out_shape
    )
@@ -871,11 +873,13 @@
    F = [i for i in range(rank) if indices[i] is None]  # Free dimensions
    I = [i for i in range(rank) if indices[i] is not None]  # Indexed dimensions
    K = len(I)
    # Determine the maximum size 'N' among the index tensors
    if K > 0:
-        index_shapes = []  # [tensor.shape[0] for tensor in indices if tensor is not None]
+        index_shapes = (
+            []
+        )  # [tensor.shape[0] for tensor in indices if tensor is not None]
        for _ni, idx_tensor in enumerate(indices):
            if idx_tensor is not None:
                if idx_tensor.shape[0] != DYNAMIC_DIM:
                    index_shapes.append(idx_tensor.shape[0])
                else:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-03-23 18:18:24.152947+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-03-23 18:18:45.151011+00:00
@@ -321,31 +321,45 @@
            # duplicate positions correctly (mirrors test_index_put_accumulate_duplicate_indices).
            param(
                test_name="1d_duplicate_indices_accumulate",
                source_tensor=torch.zeros([6], dtype=torch.float32),
                indices_tensor=(torch.tensor([0, 0, 2, 2, 2], dtype=torch.int64),),
-                value_tensor=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32),
+                value_tensor=torch.tensor(
+                    [1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32
+                ),
                accumulate=True,
            ),
            param(
                test_name="2d_indices_accumulate_True",
                source_tensor=torch.zeros([5, 5], dtype=torch.float32),
-                indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
+                indices_tensor=(
+                    torch.tensor([0, 0], dtype=torch.int32),
+                    torch.tensor([1, 1], dtype=torch.int32),
+                ),
                value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32),
                accumulate=True,
            ),
            param(
                test_name="3d_indices_accumulate_True",
                source_tensor=torch.zeros([3, 3, 3], dtype=torch.float32),
-                indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([2, 2], dtype=torch.int32)),
+                indices_tensor=(
+                    torch.tensor([0, 0], dtype=torch.int32),
+                    torch.tensor([1, 1], dtype=torch.int32),
+                    torch.tensor([2, 2], dtype=torch.int32),
+                ),
                value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32),
                accumulate=True,
            ),
            param(
                test_name="4d_indices_accumulate_True",
                source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.float32),
-                indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
+                indices_tensor=(
+                    torch.tensor([0, 0], dtype=torch.int32),
+                    torch.tensor([1, 1], dtype=torch.int32),
+                    torch.tensor([0, 0], dtype=torch.int32),
+                    torch.tensor([1, 1], dtype=torch.int32),
+                ),
                value_tensor=torch.tensor([1.0, 2.0], dtype=torch.float32),
                accumulate=True,
            ),
            # Negative indices with accumulate (mirrors test_index_put_accumulate_large_tensor).
            param(
@@ -357,29 +371,38 @@
            ),
            # bfloat16 + duplicate indices: computation stays in bfloat16 (no forced fp32 cast).
            param(
                test_name="accumulate_bfloat16_duplicate",
                source_tensor=torch.zeros([4, 4], dtype=torch.bfloat16),
-                indices_tensor=(torch.tensor([0, 0, 2], dtype=torch.int64), torch.tensor([1, 1, 3], dtype=torch.int64)),
+                indices_tensor=(
+                    torch.tensor([0, 0, 2], dtype=torch.int64),
+                    torch.tensor([1, 1, 3], dtype=torch.int64),
+                ),
                value_tensor=torch.tensor([1.0, 2.0, 4.0], dtype=torch.bfloat16),
                accumulate=True,
            ),
            # float16 + duplicate indices.
            param(
                test_name="accumulate_float16_duplicate",
                source_tensor=torch.zeros([4, 4], dtype=torch.float16),
-                indices_tensor=(torch.tensor([1, 1, 3], dtype=torch.int64), torch.tensor([0, 0, 2], dtype=torch.int64)),
+                indices_tensor=(
+                    torch.tensor([1, 1, 3], dtype=torch.int64),
+                    torch.tensor([0, 0, 2], dtype=torch.int64),
+                ),
                value_tensor=torch.tensor([2.0, 3.0, 5.0], dtype=torch.float16),
                accumulate=True,
            ),
            # Partial broadcast: one index covers a single position on dim-1 while
            # dim-0 has multiple positions — mirrors test_index_put_accumulate_expanded_values
            # (t[tensor([0,1,2,3]), tensor([1])] += 1.0).
            param(
                test_name="accumulate_partial_dim1_broadcast",
                source_tensor=torch.zeros([5, 2], dtype=torch.float32),
-                indices_tensor=(torch.tensor([0, 1, 2, 3], dtype=torch.int64), torch.tensor([1], dtype=torch.int64)),
+                indices_tensor=(
+                    torch.tensor([0, 1, 2, 3], dtype=torch.int64),
+                    torch.tensor([1], dtype=torch.int64),
+                ),
                value_tensor=torch.tensor([1.0], dtype=torch.float32),
                accumulate=True,
            ),
        ]
    )
@@ -512,12 +535,16 @@
        )
        trt_mod = torchtrt.dynamo.compile(
            ep,
            arg_inputs=[
                torchtrt.Input(shape=(16,), dtype=torch.float32),
-                torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32),
-                torchtrt.Input(min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32),
+                torchtrt.Input(
+                    min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.float32
+                ),
+                torchtrt.Input(
+                    min_shape=(1,), opt_shape=(3,), max_shape=(16,), dtype=torch.int32
+                ),
            ],
            min_block_size=1,
        )
        result = trt_mod(src.clone(), values, idx)
        assert torch.allclose(
@@ -587,11 +614,10 @@

        result = trt_mod(cache.clone(), values, idx)
        assert torch.allclose(
            result, torch_output, atol=1e-3, rtol=1e-3
        ), f"KV-cache index_put mismatch: max diff = {(result - torch_output).abs().max()}"
-

    def test_accumulate_random_walk_duplicate_indices(self):
        """accumulate=True on 1-D input where indices are generated by a random walk
        (many duplicates interleaved).  Mirrors PyTorch's
        test_index_put_accumulate_duplicate_indices, scaled to a TRT-friendly size.
@@ -664,13 +690,13 @@
                torchtrt.Input(shape=(1, 2), dtype=torch.int64),
            ],
            min_block_size=1,
        )
        result = trt_mod(src.clone(), values, i0, i1, i2)
-        assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
-            f"3D expand accumulate mismatch: max diff = {(result - torch_output).abs().max()}"
-        )
+        assert torch.allclose(
+            result, torch_output, atol=1e-4, rtol=1e-4
+        ), f"3D expand accumulate mismatch: max diff = {(result - torch_output).abs().max()}"

    def test_empty_index_no_op(self):
        """index_put with an empty index tensor is a no-op — output equals input.

        Mirrors PyTorch's test_empty_index: x[empty_idx] = values leaves x unchanged.
@@ -697,13 +723,13 @@
                torchtrt.Input(shape=(0,), dtype=torch.int64),
            ],
            min_block_size=1,
        )
        result = trt_mod(src.clone(), values, idx)
-        assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
-            f"Empty-index no-op mismatch: {result} vs {torch_output}"
-        )
+        assert torch.allclose(
+            result, torch_output, atol=1e-4, rtol=1e-4
+        ), f"Empty-index no-op mismatch: {result} vs {torch_output}"

    def test_index_ind_dtype_int_vs_long(self):
        """int32 and int64 index tensors must produce identical results.

        Mirrors PyTorch's test_index_ind_dtype.
@@ -725,27 +751,41 @@
        assert torch.allclose(ref_long, ref_int), "CPU int32 vs int64 mismatch"

        ep_long = torch.export.export(model, args=(src, values, idx_long))
        ep_int = torch.export.export(model, args=(src, values, idx_int))

-        trt_long = torchtrt.dynamo.compile(ep_long, arg_inputs=[
-            torchtrt.Input(shape=(4, 4), dtype=torch.float32),
-            torchtrt.Input(shape=(4,), dtype=torch.float32),
-            torchtrt.Input(shape=(4,), dtype=torch.int64),
-        ], min_block_size=1)
-        trt_int = torchtrt.dynamo.compile(ep_int, arg_inputs=[
-            torchtrt.Input(shape=(4, 4), dtype=torch.float32),
-            torchtrt.Input(shape=(4,), dtype=torch.float32),
-            torchtrt.Input(shape=(4,), dtype=torch.int32),
-        ], min_block_size=1)
+        trt_long = torchtrt.dynamo.compile(
+            ep_long,
+            arg_inputs=[
+                torchtrt.Input(shape=(4, 4), dtype=torch.float32),
+                torchtrt.Input(shape=(4,), dtype=torch.float32),
+                torchtrt.Input(shape=(4,), dtype=torch.int64),
+            ],
+            min_block_size=1,
+        )
+        trt_int = torchtrt.dynamo.compile(
+            ep_int,
+            arg_inputs=[
+                torchtrt.Input(shape=(4, 4), dtype=torch.float32),
+                torchtrt.Input(shape=(4,), dtype=torch.float32),
+                torchtrt.Input(shape=(4,), dtype=torch.int32),
+            ],
+            min_block_size=1,
+        )

        out_long = trt_long(src.clone(), values, idx_long)
        out_int = trt_int(src.clone(), values, idx_int)

-        assert torch.allclose(out_long, ref_long, atol=1e-4, rtol=1e-4), "TRT int64 mismatch"
-        assert torch.allclose(out_int, ref_int, atol=1e-4, rtol=1e-4), "TRT int32 mismatch"
-        assert torch.allclose(out_long, out_int, atol=1e-4, rtol=1e-4), "TRT int32 vs int64 inconsistency"
+        assert torch.allclose(
+            out_long, ref_long, atol=1e-4, rtol=1e-4
+        ), "TRT int64 mismatch"
+        assert torch.allclose(
+            out_int, ref_int, atol=1e-4, rtol=1e-4
+        ), "TRT int32 mismatch"
+        assert torch.allclose(
+            out_long, out_int, atol=1e-4, rtol=1e-4
+        ), "TRT int32 vs int64 inconsistency"

    def test_accumulate_non_contiguous_source(self):
        """accumulate=True on a non-contiguous (sliced) source tensor.

        Mirrors PyTorch's test_index_put_accumulate_non_contiguous.
@@ -780,26 +820,24 @@
                torchtrt.Input(shape=(2,), dtype=torch.int64),
            ],
            min_block_size=1,
        )
        result = trt_mod(src_slice.contiguous(), values, idx)
-        assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
-            f"Non-contiguous accumulate mismatch: max diff = {(result - torch_output).abs().max()}"
-        )
+        assert torch.allclose(
+            result, torch_output, atol=1e-4, rtol=1e-4
+        ), f"Non-contiguous accumulate mismatch: max diff = {(result - torch_output).abs().max()}"

    def test_accumulate_expanded_values_broadcast(self):
        """accumulate=True with value broadcasting — 0D scalar and 1D values
        broadcast across unique indexed positions.

        Mirrors PyTorch's test_index_put_accumulate_expanded_values (unique indices only).
        """

        class AccumBroadcast(torch.nn.Module):
            def forward(self, src, values, idx0, idx1):
-                return torch.ops.aten.index_put.default(
-                    src, [idx0, idx1], values, True
-                )
+                return torch.ops.aten.index_put.default(src, [idx0, idx1], values, True)

        src = torch.zeros(5, 2, dtype=torch.float32, device="cuda")
        idx0 = torch.tensor([0, 1, 2, 3], dtype=torch.int64, device="cuda")
        idx1 = torch.tensor([0, 1, 0, 1], dtype=torch.int64, device="cuda")
        values_1d = torch.tensor([1.0], dtype=torch.float32, device="cuda")
@@ -817,12 +855,12 @@
                torchtrt.Input(shape=(4,), dtype=torch.int64),
            ],
            min_block_size=1,
        )
        result = trt_mod(src.clone(), values_1d, idx0, idx1)
-        assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4), (
-            f"Accumulate broadcast mismatch: max diff = {(result - torch_output).abs().max()}"
-        )
+        assert torch.allclose(
+            result, torch_output, atol=1e-4, rtol=1e-4
+        ), f"Accumulate broadcast mismatch: max diff = {(result - torch_output).abs().max()}"


if __name__ == "__main__":
    run_tests()

@narendasan narendasan force-pushed the narendasan/push-trxznozvxnsq branch from 8eb87a5 to 3402e30 Compare March 23, 2026 18:19
@narendasan narendasan requested a review from zewenli98 March 23, 2026 18:25
@narendasan narendasan force-pushed the narendasan/push-trxznozvxnsq branch from 3402e30 to 24f99f4 Compare March 24, 2026 15:39

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2026-03-24 15:40:02.378996+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2026-03-24 15:40:21.128137+00:00
@@ -873,11 +873,13 @@
    F = [i for i in range(rank) if indices[i] is None]  # Free dimensions
    I = [i for i in range(rank) if indices[i] is not None]  # Indexed dimensions
    K = len(I)
    # Determine the maximum size 'N' among the index tensors
    if K > 0:
-        index_shapes = []  # [tensor.shape[0] for tensor in indices if tensor is not None]
+        index_shapes = (
+            []
+        )  # [tensor.shape[0] for tensor in indices if tensor is not None]
        for _ni, idx_tensor in enumerate(indices):
            if idx_tensor is not None:
                if idx_tensor.shape[0] != DYNAMIC_DIM:
                    index_shapes.append(idx_tensor.shape[0])
                else:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py	2026-03-24 15:40:02.394519+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py	2026-03-24 15:40:22.366634+00:00
@@ -158,11 +158,13 @@
                out = torch.ops.aten.index.Tensor(x, indices)
                return out

        input = torch.randn(2, 2)
        index0 = torch.tensor([True, False])
-        self.run_test(TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True)
+        self.run_test(
+            TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True
+        )

    def test_index_zero_index_three_dim_ITensor(self):
        class TestModule(nn.Module):
            def forward(self, x, index0):
                indices = [None, index0, None]
@@ -170,11 +172,13 @@
                return out

        input = torch.randn(2, 2, 2)
        index0 = torch.randint(0, 1, (1, 1))
        index0 = index0.to(torch.int32)
-        self.run_test(TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True)
+        self.run_test(
+            TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True
+        )

    @unittest.skipIf(
        ENABLED_FEATURES.tensorrt_rtx,
        "Skipped on tensorrt_rtx due to nonzero not supported",
    )
@@ -185,11 +189,13 @@
                out = torch.ops.aten.index.Tensor(x, indices)
                return out

        input = torch.randn(2, 2, 2)
        index0 = torch.tensor([True, False])
-        self.run_test(TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True)
+        self.run_test(
+            TestModule(), [input, index0], use_dynamo_tracer=True, enable_passes=True
+        )


class TestIndexDynamicConstantConverter(DispatchTestCase):
    @parameterized.expand(
        [

@narendasan narendasan force-pushed the narendasan/push-trxznozvxnsq branch from 24f99f4 to 503ca51 Compare March 24, 2026 16:05

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_embed_engines.py	2026-03-24 18:29:04.478330+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_embed_engines.py	2026-03-24 18:29:24.473246+00:00
@@ -14,13 +14,11 @@
except (ImportError, RuntimeError):
    HAS_TORCHVISION = False


class TestModelToEngineToModel(unittest.TestCase):
-    @unittest.skipIf(
-        not HAS_TORCHVISION, "torchvision not available"
-    )
+    @unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
    @unittest.skipIf(
        torchtrt.ENABLED_FEATURES.tensorrt_rtx,
        "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
    )
    def test_resnet50(self):
--- /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_module_fallback.py	2026-03-24 18:29:04.478330+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_module_fallback.py	2026-03-24 18:29:24.488458+00:00
@@ -14,13 +14,11 @@

@unittest.skipIf(
    torchtrt.ENABLED_FEATURES.tensorrt_rtx,
    "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
)
-@unittest.skipIf(
-    not HAS_TORCHVISION, "torchvision not available"
-)
+@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
class TestModuleFallback(unittest.TestCase):
    def test_fallback_resnet18(self):
        self.model = models.resnet18(pretrained=True).eval().to("cuda")
        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
        compile_spec = {
--- /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_operator_fallback.py	2026-03-24 18:29:04.478330+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_operator_fallback.py	2026-03-24 18:29:24.518305+00:00
@@ -14,13 +14,11 @@

@unittest.skipIf(
    torchtrt.ENABLED_FEATURES.tensorrt_rtx,
    "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
)
-@unittest.skipIf(
-    not HAS_TORCHVISION, "torchvision not available"
-)
+@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
class TestFallbackModels(unittest.TestCase):
    def test_fallback_resnet18(self):
        self.model = models.resnet18(pretrained=True).eval().to("cuda")
        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
        compile_spec = {
--- /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_e2e_behavior.py	2026-03-24 18:29:04.478330+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_e2e_behavior.py	2026-03-24 18:29:24.569646+00:00
@@ -16,13 +16,11 @@

@unittest.skipIf(
    torchtrt.ENABLED_FEATURES.tensorrt_rtx,
    "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
)
-@unittest.skipIf(
-    not HAS_TORCHVISION, "torchvision not available"
-)
+@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
class TestInputTypeDefaultsFP32Model(unittest.TestCase):

    def test_input_use_default_fp32(self):
        self.model = models.resnet18(pretrained=True).eval().to("cuda")
        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
@@ -68,13 +66,11 @@
class TestInputTypeDefaultsFP16Model(unittest.TestCase):
    @unittest.skipIf(
        torchtrt.ENABLED_FEATURES.tensorrt_rtx,
        "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
    )
-    @unittest.skipIf(
-        not HAS_TORCHVISION, "torchvision not available"
-    )
+    @unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
    def test_input_use_default_fp16(self):
        self.model = models.resnet18(pretrained=True).eval().to("cuda")
        self.input = torch.randn((1, 3, 224, 224)).to("cuda")

        half_mod = torch.jit.script(self.model)
@@ -89,13 +85,11 @@

    @unittest.skipIf(
        torchtrt.ENABLED_FEATURES.tensorrt_rtx,
        "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
    )
-    @unittest.skipIf(
-        not HAS_TORCHVISION, "torchvision not available"
-    )
+    @unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
    def test_input_use_default_fp16_without_fp16_enabled(self):
        self.model = models.resnet18(pretrained=True).eval().to("cuda")
        self.input = torch.randn((1, 3, 224, 224)).to("cuda")

        half_mod = torch.jit.script(self.model)
@@ -108,13 +102,11 @@

    @unittest.skipIf(
        torchtrt.ENABLED_FEATURES.tensorrt_rtx,
        "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
    )
-    @unittest.skipIf(
-        not HAS_TORCHVISION, "torchvision not available"
-    )
+    @unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
    def test_input_respect_user_setting_fp16_weights_fp32_in(self):
        self.model = models.resnet18(pretrained=True).eval().to("cuda")
        self.input = torch.randn((1, 3, 224, 224)).to("cuda")

        half_mod = torch.jit.script(self.model)
@@ -130,13 +122,11 @@

    @unittest.skipIf(
        torchtrt.ENABLED_FEATURES.tensorrt_rtx,
        "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
    )
-    @unittest.skipIf(
-        not HAS_TORCHVISION, "torchvision not available"
-    )
+    @unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
    def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor(self):
        self.model = models.resnet18(pretrained=True).eval().to("cuda")
        self.input = torch.randn((1, 3, 224, 224)).to("cuda")

        half_mod = torch.jit.script(self.model)
--- /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_ts_backend.py	2026-03-24 18:29:04.478485+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/ts/api/test_ts_backend.py	2026-03-24 18:29:24.694311+00:00
@@ -12,13 +12,11 @@
    HAS_TORCHVISION = True
except (ImportError, RuntimeError):
    HAS_TORCHVISION = False


-@unittest.skipIf(
-    not HAS_TORCHVISION, "torchvision not available"
-)
+@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
class TestCompile(unittest.TestCase):
    def test_compile_traced(self):
        self.model = models.vgg16(pretrained=True).eval().to("cuda")
        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
        self.traced_model = torch.jit.trace(self.model, [self.input])
@@ -129,13 +127,11 @@

@unittest.skipIf(
    torchtrt.ENABLED_FEATURES.tensorrt_rtx,
    "aten::adaptive_avg_pool2d is implemented via plugins which is not supported for tensorrt_rtx",
)
-@unittest.skipIf(
-    not HAS_TORCHVISION, "torchvision not available"
-)
+@unittest.skipIf(not HAS_TORCHVISION, "torchvision not available")
class TestCheckMethodOpSupport(unittest.TestCase):
    def test_check_support(self):
        module = models.alexnet(pretrained=True).eval().to("cuda")
        self.module = torch.jit.trace(module, torch.ones((1, 3, 224, 224)).to("cuda"))

@narendasan narendasan force-pushed the narendasan/push-trxznozvxnsq branch from 4352d78 to fb2f692 Compare March 25, 2026 03:41
@narendasan narendasan force-pushed the narendasan/push-trxznozvxnsq branch from fb2f692 to 314ec6c Compare March 30, 2026 20:38
Collaborator

@zewenli98 zewenli98 left a comment


The changes look good to me. My only concern for _index_put_scatter_add is the perf. Is it faster than leaving it in PyTorch?

Comment on lines +10 to +11
# reduction mode. The current implementation (scatter into zeros + elementwise
# add) only gives correct results when every scattered index is unique.

The new implementation should have supported the cases with duplicated indices right?

Collaborator Author


I believe so; I will add a test for it.
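For reference, the duplicate-index behavior such a test needs to pin down can be sketched in plain PyTorch (shapes and values here are illustrative, not taken from this PR):

```python
import torch

# With accumulate=True, writes that target the same position must sum
# rather than overwrite -- the case a plain scatter-into-zeros would get wrong.
src = torch.zeros(4)
idx = (torch.tensor([1, 1, 3]),)        # index 1 appears twice
vals = torch.tensor([10.0, 20.0, 5.0])

out = src.index_put(idx, vals, accumulate=True)
# Position 1 receives 10 + 20 = 30; position 3 receives 5.
print(out)  # tensor([ 0., 30.,  0.,  5.])
```

The compiled TRT module should match this eager result bit-for-bit up to float tolerance.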

@narendasan
Collaborator Author

The changes look good to me. My only concern for _index_put_scatter_add is the perf. Is it faster than leaving it in PyTorch?

In rough tests it seemed competitive, but I didn't have a good model on hand that uses it. If you have a reference model, I can run it.
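Absent a reference model, a rough micro-benchmark of the eager op gives a first-order baseline; the sizes below are made up for illustration, and on a real run you would compile the same module with torch_tensorrt and time both sides on CUDA:

```python
import time
import torch

def bench(fn, *args, iters=50):
    # Coarse wall-clock timing; on CUDA, add torch.cuda.synchronize()
    # around the timed region.
    fn(*args)  # warmup
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(*args)
    return out, (time.perf_counter() - start) / iters

# Illustrative KV-cache-like sizes (assumptions, not taken from the PR).
cache = torch.zeros(1024, 64)
slots = torch.randint(0, 1024, (4096,))   # many duplicate slots on purpose
vals = torch.randn(4096, 64)

def eager_index_put(cache, slots, vals):
    return cache.index_put((slots,), vals, accumulate=True)

out, per_iter = bench(eager_index_put, cache, slots, vals)
print(f"eager index_put(accumulate=True): {per_iter * 1e3:.3f} ms/iter")
```

Swapping `eager_index_put` for the TRT-compiled module in the same harness would give the apples-to-apples number the question asks for.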


@github-actions github-actions bot left a comment


There are some changes that do not conform to C++ style guidelines:

diff --git a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.cpp b/tmp/changes.txt
index fbf81a9..e7a3620 100644
--- a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.cpp
+++ b/tmp/changes.txt
@@ -15,7 +15,6 @@ namespace impl {
// Helpers
// ---------------------------------------------------------------------------

-
// ---------------------------------------------------------------------------
// ScatterAddPlugin — construction
// ---------------------------------------------------------------------------
@@ -26,8 +25,7 @@ ScatterAddPlugin::ScatterAddPlugin() = default;
// IPluginV3
// ---------------------------------------------------------------------------

-nvinfer1::IPluginCapability* ScatterAddPlugin::getCapabilityInterface(
-    nvinfer1::PluginCapabilityType type) noexcept {
+nvinfer1::IPluginCapability* ScatterAddPlugin::getCapabilityInterface(nvinfer1::PluginCapabilityType type) noexcept {
  switch (type) {
    case nvinfer1::PluginCapabilityType::kCORE:
      return static_cast<nvinfer1::IPluginV3OneCore*>(this);
@@ -105,15 +103,13 @@ bool ScatterAddPlugin::supportsFormatCombination(

  // Positions 1 through nbInputs-2 are index tensors: int32 or int64.
  if (pos >= 1 && pos <= nbInputs - 2) {
-    return desc.desc.type == nvinfer1::DataType::kINT32 ||
-           desc.desc.type == nvinfer1::DataType::kINT64;
+    return desc.desc.type == nvinfer1::DataType::kINT32 || desc.desc.type == nvinfer1::DataType::kINT64;
  }

  // pos 0 (src), pos nbInputs-1 (values), pos nbInputs (output):
  // float32 / float16 / bfloat16, all sharing the same type.
-  const bool float_type = desc.desc.type == nvinfer1::DataType::kFLOAT ||
-                          desc.desc.type == nvinfer1::DataType::kHALF ||
-                          desc.desc.type == nvinfer1::DataType::kBF16;
+  const bool float_type = desc.desc.type == nvinfer1::DataType::kFLOAT || desc.desc.type == nvinfer1::DataType::kHALF ||
+      desc.desc.type == nvinfer1::DataType::kBF16;
  if (!float_type) {
    return false;
  }
@@ -131,7 +127,7 @@ int32_t ScatterAddPlugin::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* /*out*/,
    int32_t /*nbOutputs*/) noexcept {
  dtype_ = in[0].desc.type;
-  n_indices_ = nbInputs - 2;  // exclude src and values
+  n_indices_ = nbInputs - 2; // exclude src and values
  idx_dtypes_.resize(n_indices_);
  for (int i = 0; i < n_indices_; ++i) {
    idx_dtypes_[i] = in[1 + i].desc.type;
@@ -180,8 +176,7 @@ int32_t ScatterAddPlugin::enqueue(
  const auto float_opts = at::TensorOptions().device(at::kCUDA).dtype(float_dtype);

  at::Tensor src = at::from_blob(const_cast<void*>(inputs[0]), src_shape_, float_opts);
-  at::Tensor val = at::from_blob(
-      const_cast<void*>(inputs[n_indices_ + 1]), val_shape_, float_opts);
+  at::Tensor val = at::from_blob(const_cast<void*>(inputs[n_indices_ + 1]), val_shape_, float_opts);

  // Build the indices list — one entry per index tensor, all cast to int64
  // as required by ATen's index_put kernel.
@@ -190,8 +185,7 @@ int32_t ScatterAddPlugin::enqueue(
  for (int i = 0; i < n_indices_; ++i) {
    const at::ScalarType int_dtype = util::TRTDataTypeToScalarType(idx_dtypes_[i]);
    const auto int_opts = at::TensorOptions().device(at::kCUDA).dtype(int_dtype);
-    at::Tensor idx = at::from_blob(
-        const_cast<void*>(inputs[1 + i]), idx_shapes_[i], int_opts);
+    at::Tensor idx = at::from_blob(const_cast<void*>(inputs[1 + i]), idx_shapes_[i], int_opts);
    indices.push_back(std::optional<at::Tensor>(idx.to(torch::kLong)));
  }

@@ -222,8 +216,7 @@ int32_t ScatterAddPlugin::enqueue(
  return 0;
}

-nvinfer1::IPluginV3* ScatterAddPlugin::attachToContext(
-    nvinfer1::IPluginResourceContext* /*context*/) noexcept {
+nvinfer1::IPluginV3* ScatterAddPlugin::attachToContext(nvinfer1::IPluginResourceContext* /*context*/) noexcept {
  return clone();
}

diff --git a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.h b/tmp/changes.txt
index 0585f22..f882874 100644
--- a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.h
+++ b/tmp/changes.txt
@@ -110,15 +110,14 @@ class ScatterAddPlugin : public nvinfer1::IPluginV3,
  // Captured at configurePlugin / onShapeChange time.
  // Layout: input[0]=src, input[1..n_indices_]=indices, input[n_indices_+1]=values
  nvinfer1::DataType dtype_{nvinfer1::DataType::kFLOAT};
-  std::vector<nvinfer1::DataType> idx_dtypes_;  // one per index input
+  std::vector<nvinfer1::DataType> idx_dtypes_; // one per index input
  std::vector<int64_t> src_shape_;
  std::vector<int64_t> val_shape_;
-  std::vector<std::vector<int64_t>> idx_shapes_;  // full shape of each index tensor
-  int32_t n_indices_{0};            // number of index tensors
+  std::vector<std::vector<int64_t>> idx_shapes_; // full shape of each index tensor
+  int32_t n_indices_{0}; // number of index tensors

  // Empty field collection — this plugin has no serializable attributes.
  nvinfer1::PluginFieldCollection empty_fc_{0, nullptr};
-
};

// ---------------------------------------------------------------------------
ERROR: Some files do not conform to style guidelines


@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2026-04-02 22:43:56.836073+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2026-04-02 22:44:20.017903+00:00
@@ -588,11 +588,10 @@
    set_layer_name(gather_layer, target, name + "_gather_layer_element", source_ir)
    out = gather_layer.get_output(0)
    return out


-
def index_put_scatter_add_plugin(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
@@ -608,13 +607,13 @@

    Supports any number of non-None index tensors (N >= 1).  The plugin
    inputs are laid out as: [src, idx_0, ..., idx_{N-1}, values].
    """
    non_none_indices = [x for x in input_indices if x is not None]
-    assert len(non_none_indices) >= 1, (
-        "ScatterAdd plugin requires at least one non-None index tensor"
-    )
+    assert (
+        len(non_none_indices) >= 1
+    ), "ScatterAdd plugin requires at least one non-None index tensor"

    # Plugin supports float32/float16/bfloat16; cast other types through float32.
    _supported_float_dtypes = (trt.float32, trt.float16, trt.bfloat16)
    original_dtype = input_tensor.dtype
    if original_dtype not in _supported_float_dtypes:
@@ -635,18 +634,16 @@
    if not isinstance(values, ITensor):
        values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=1)

    # Values must match src dtype after any cast above.
    if values.dtype != input_tensor.dtype:
-        values = cast_trt_tensor(
-            ctx, values, input_tensor.dtype, f"{name}_values_cast"
-        )
+        values = cast_trt_tensor(ctx, values, input_tensor.dtype, f"{name}_values_cast")

    creator = trt.get_plugin_registry().get_creator("ScatterAdd", "1", "torch_tensorrt")
-    assert creator is not None, (
-        "ScatterAdd plugin creator not found — is torch_tensorrt_runtime loaded?"
-    )
+    assert (
+        creator is not None
+    ), "ScatterAdd plugin creator not found — is torch_tensorrt_runtime loaded?"

    pfc = trt.PluginFieldCollection([])
    plugin = creator.create_plugin("ScatterAdd", pfc, trt.TensorRTPhase.BUILD)
    assert plugin is not None, "Failed to create ScatterAdd plugin instance"

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2026-04-02 22:43:56.833500+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2026-04-02 22:44:21.995858+00:00
@@ -1110,11 +1110,12 @@
    ScatterAdd plugin is present in the TRT plugin registry (registered
    at library init time when torch_tensorrt_runtime is loaded).
    Supports any number of non-None index tensors.
    """
    if not (
-        index_dtype_validator(node, settings) and index_nonbool_validator(node, settings)
+        index_dtype_validator(node, settings)
+        and index_nonbool_validator(node, settings)
    ):
        return False
    if not args_bounds_check(node.args, 3, False):
        return False
    return _scatter_add_plugin_available()
@@ -1176,12 +1177,10 @@
        name,
        args[0],
        args[1],
        args[2],
    )
-
-


@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor, supports_dynamic_shapes=True)
@enforce_tensor_types(
    {
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-04-02 22:43:56.864073+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-04-02 22:44:23.373273+00:00
@@ -859,11 +859,10 @@
        result = trt_mod(src.clone(), values_1d, idx0, idx1)
        assert torch.allclose(
            result, torch_output, atol=1e-4, rtol=1e-4
        ), f"Accumulate broadcast mismatch: max diff = {(result - torch_output).abs().max()}"

-
    # ------------------------------------------------------------------
    # Duplicate-index tests for realistic use-case models
    # These mirror the scenarios in experiments/bench_index_put_scatter_add.py
    # and verify that _index_put_scatter_add correctly accumulates into
    # duplicate positions when index_put is embedded in a larger graph.
@@ -970,11 +969,11 @@
                self.conv = torch.nn.Conv1d(C, 8, kernel_size=3, padding=1)
                self.head = torch.nn.Linear(8, 4, bias=False)

            def forward(self, signal, hist):
                # signal: (1, C, L), hist: (n_bins, 8)
-                feat = self.conv(signal).squeeze(0).T    # (L, 8)
+                feat = self.conv(signal).squeeze(0).T  # (L, 8)
                hist = hist.index_put((get_bin_ids(),), feat, accumulate=True)
                return self.head(hist.mean(dim=0))

        signal = torch.randn(1, C, L)
        hist = torch.zeros(n_bins, 8)

@narendasan narendasan force-pushed the narendasan/push-trxznozvxnsq branch from c6e3a52 to ac26d06 Compare April 2, 2026 22:44

@github-actions github-actions bot left a comment


There are some changes that do not conform to C++ style guidelines:

diff --git a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.cpp b/tmp/changes.txt
index 038667c..ac7c321 100644
--- a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.cpp
+++ b/tmp/changes.txt
@@ -11,11 +11,9 @@ namespace core {
namespace plugins {
namespace impl {

-
ScatterAddPlugin::ScatterAddPlugin() = default;

-nvinfer1::IPluginCapability* ScatterAddPlugin::getCapabilityInterface(
-    nvinfer1::PluginCapabilityType type) noexcept {
+nvinfer1::IPluginCapability* ScatterAddPlugin::getCapabilityInterface(nvinfer1::PluginCapabilityType type) noexcept {
  switch (type) {
    case nvinfer1::PluginCapabilityType::kCORE:
      return static_cast<nvinfer1::IPluginV3OneCore*>(this);
@@ -93,15 +91,13 @@ bool ScatterAddPlugin::supportsFormatCombination(

  // Positions 1 through nbInputs-2 are index tensors: int32 or int64.
  if (pos >= 1 && pos <= nbInputs - 2) {
-    return desc.desc.type == nvinfer1::DataType::kINT32 ||
-           desc.desc.type == nvinfer1::DataType::kINT64;
+    return desc.desc.type == nvinfer1::DataType::kINT32 || desc.desc.type == nvinfer1::DataType::kINT64;
  }

  // pos 0 (src), pos nbInputs-1 (values), pos nbInputs (output):
  // float32 / float16 / bfloat16, all sharing the same type.
-  const bool float_type = desc.desc.type == nvinfer1::DataType::kFLOAT ||
-                          desc.desc.type == nvinfer1::DataType::kHALF ||
-                          desc.desc.type == nvinfer1::DataType::kBF16;
+  const bool float_type = desc.desc.type == nvinfer1::DataType::kFLOAT || desc.desc.type == nvinfer1::DataType::kHALF ||
+      desc.desc.type == nvinfer1::DataType::kBF16;
  if (!float_type) {
    return false;
  }
@@ -119,7 +115,7 @@ int32_t ScatterAddPlugin::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* /*out*/,
    int32_t /*nbOutputs*/) noexcept {
  dtype_ = in[0].desc.type;
-  n_indices_ = nbInputs - 2;  // exclude src and values
+  n_indices_ = nbInputs - 2; // exclude src and values
  idx_dtypes_.resize(n_indices_);
  for (int i = 0; i < n_indices_; ++i) {
    idx_dtypes_[i] = in[1 + i].desc.type;
@@ -168,8 +164,7 @@ int32_t ScatterAddPlugin::enqueue(
  const auto float_opts = at::TensorOptions().device(at::kCUDA).dtype(float_dtype);

  at::Tensor src = at::from_blob(const_cast<void*>(inputs[0]), src_shape_, float_opts);
-  at::Tensor val = at::from_blob(
-      const_cast<void*>(inputs[n_indices_ + 1]), val_shape_, float_opts);
+  at::Tensor val = at::from_blob(const_cast<void*>(inputs[n_indices_ + 1]), val_shape_, float_opts);

  // Build the indices list — one entry per index tensor, all cast to int64
  // as required by ATen's index_put kernel.
@@ -178,8 +173,7 @@ int32_t ScatterAddPlugin::enqueue(
  for (int i = 0; i < n_indices_; ++i) {
    const at::ScalarType int_dtype = util::TRTDataTypeToScalarType(idx_dtypes_[i]);
    const auto int_opts = at::TensorOptions().device(at::kCUDA).dtype(int_dtype);
-    at::Tensor idx = at::from_blob(
-        const_cast<void*>(inputs[1 + i]), idx_shapes_[i], int_opts);
+    at::Tensor idx = at::from_blob(const_cast<void*>(inputs[1 + i]), idx_shapes_[i], int_opts);
    indices.push_back(std::optional<at::Tensor>(idx.to(torch::kLong)));
  }

@@ -210,8 +204,7 @@ int32_t ScatterAddPlugin::enqueue(
  return 0;
}

-nvinfer1::IPluginV3* ScatterAddPlugin::attachToContext(
-    nvinfer1::IPluginResourceContext* /*context*/) noexcept {
+nvinfer1::IPluginV3* ScatterAddPlugin::attachToContext(nvinfer1::IPluginResourceContext* /*context*/) noexcept {
  return clone();
}

diff --git a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.h b/tmp/changes.txt
index 0585f22..f882874 100644
--- a/home/runner/work/TensorRT/TensorRT/core/plugins/impl/scatter_add_plugin.h
+++ b/tmp/changes.txt
@@ -110,15 +110,14 @@ class ScatterAddPlugin : public nvinfer1::IPluginV3,
  // Captured at configurePlugin / onShapeChange time.
  // Layout: input[0]=src, input[1..n_indices_]=indices, input[n_indices_+1]=values
  nvinfer1::DataType dtype_{nvinfer1::DataType::kFLOAT};
-  std::vector<nvinfer1::DataType> idx_dtypes_;  // one per index input
+  std::vector<nvinfer1::DataType> idx_dtypes_; // one per index input
  std::vector<int64_t> src_shape_;
  std::vector<int64_t> val_shape_;
-  std::vector<std::vector<int64_t>> idx_shapes_;  // full shape of each index tensor
-  int32_t n_indices_{0};            // number of index tensors
+  std::vector<std::vector<int64_t>> idx_shapes_; // full shape of each index tensor
+  int32_t n_indices_{0}; // number of index tensors

  // Empty field collection — this plugin has no serializable attributes.
  nvinfer1::PluginFieldCollection empty_fc_{0, nullptr};
-
};

// ---------------------------------------------------------------------------
ERROR: Some files do not conform to style guidelines


@narendasan narendasan force-pushed the narendasan/push-trxznozvxnsq branch from ac26d06 to 65a6c46 Compare April 2, 2026 22:46

@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-04-03 21:29:49.036765+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_put_aten.py	2026-04-03 21:30:12.049600+00:00
@@ -873,11 +873,12 @@
    # and verify that _index_put_scatter_add correctly accumulates into
    # duplicate positions when index_put is embedded in a larger graph.
    # ------------------------------------------------------------------

    @pytest.mark.skipif(
-        ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+        ENABLED_FEATURES.tensorrt_rtx,
+        reason="ScatterAdd plugin not available in TRT RTX",
    )
    def test_kv_cache_duplicate_slot_writes(self):
        """KV-cache style: linear projection → index_put(accumulate=True) into
        a flat cache with duplicate slot indices → output projection.

@@ -914,11 +915,12 @@
            enable_passes=True,
            use_explicit_typing=True,
        )

    @pytest.mark.skipif(
-        ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+        ENABLED_FEATURES.tensorrt_rtx,
+        reason="ScatterAdd plugin not available in TRT RTX",
    )
    def test_sparse_embedding_duplicate_seq_ids(self):
        """Sparse embedding accumulation: embedding lookup → index_put(accumulate=True)
        into per-sequence accumulators where many tokens map to the same sequence → ReLU.

@@ -958,11 +960,12 @@
            enable_passes=True,
            use_explicit_typing=True,
        )

    @pytest.mark.skipif(
-        ENABLED_FEATURES.tensorrt_rtx, reason="ScatterAdd plugin not available in TRT RTX"
+        ENABLED_FEATURES.tensorrt_rtx,
+        reason="ScatterAdd plugin not available in TRT RTX",
    )
    def test_histogram_conv_duplicate_bin_ids(self):
        """Histogram accumulation: Conv1d → index_put(accumulate=True) into histogram
        bins where many frames land in the same bin → mean pool → linear.

@narendasan narendasan force-pushed the narendasan/push-trxznozvxnsq branch from 25a0c4a to c57b696 Compare April 3, 2026 23:20

Labels

cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: core Issues re: The core compiler component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests

Projects

None yet

2 participants