
Conversation

@narendasan
Collaborator

Description

Issue #4037 reported a bug where the cat converter was falling into a code path that, AFAICT, is intended to manufacture Shape Tensors for dynamic shape upsampling. This PR refactors the cat converter to separate these two use cases so that it's easier to understand what is happening. It also adds a number of test cases.

Fixes #4037
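The new tests exercise dtype promotion when concatenating mixed-precision tensors (including bfloat16, which the linked issue says `unify_and_concat_trt_tensors` mishandled). A minimal eager-PyTorch sketch of the expected behavior the converter should match (this is an illustration of PyTorch's promotion rules, not the converter code itself):

```python
import torch

# Mixed fp16 / fp32 inputs: cat promotes to the wider type, fp32.
a = torch.ones(2, 3, dtype=torch.float16)
b = torch.ones(2, 3, dtype=torch.float32)
assert torch.cat((a, b), dim=0).dtype == torch.float32

# bf16 + fp16 promote to fp32 (neither can represent the other exactly).
c = torch.ones(2, 3, dtype=torch.bfloat16)
assert torch.cat((c, a), dim=0).dtype == torch.float32

# All-bf16 inputs stay bf16 -- the case the converter reportedly broke.
d = torch.full((2, 3), 2.0, dtype=torch.bfloat16)
assert torch.cat((c, d), dim=0).dtype == torch.bfloat16
```

The TRT converter's unify step should pick the same promoted dtype and cast each input to it before the concatenation layer, rather than assuming a float32 target.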

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@narendasan narendasan requested a review from apbose January 29, 2026 01:57
@meta-cla meta-cla bot added the cla signed label Jan 29, 2026
@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Jan 29, 2026
@github-actions github-actions bot requested a review from cehongwang January 29, 2026 01:57

@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_cat_aten.py	2026-01-29 01:57:26.838331+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_cat_aten.py	2026-01-29 01:58:08.317313+00:00
@@ -141,11 +141,13 @@
        """Test cat with mixed float32 and float16 tensors - should promote to float32"""

        class MixedDtypeCat(nn.Module):
            def __init__(self):
                super().__init__()
-                self.register_buffer("const_fp16", torch.ones(2, 3, dtype=torch.float16))
+                self.register_buffer(
+                    "const_fp16", torch.ones(2, 3, dtype=torch.float16)
+                )

            def forward(self, x):
                # x is float32, const_fp16 is float16, result should be float32
                return torch.ops.aten.cat.default((self.const_fp16, x), 0)

@@ -177,12 +179,16 @@
        """Test cat with three different dtypes - bfloat16, float16, float32"""

        class ThreeDtypeCat(nn.Module):
            def __init__(self):
                super().__init__()
-                self.register_buffer("const_bf16", torch.ones(2, 3, dtype=torch.bfloat16))
-                self.register_buffer("const_fp16", torch.ones(2, 3, dtype=torch.float16))
+                self.register_buffer(
+                    "const_bf16", torch.ones(2, 3, dtype=torch.bfloat16)
+                )
+                self.register_buffer(
+                    "const_fp16", torch.ones(2, 3, dtype=torch.float16)
+                )

            def forward(self, x):
                # bf16, fp16, fp32 -> should promote to fp32
                return torch.ops.aten.cat.default(
                    (self.const_bf16, self.const_fp16, x), 0
@@ -197,11 +203,13 @@
    def test_cat_many_tensors(self):
        """Test cat with many tensors (10+)"""

        class ManyCat(nn.Module):
            def forward(self, t0, t1, t2, t3, t4, t5, t6, t7, t8, t9):
-                return torch.ops.aten.cat.default((t0, t1, t2, t3, t4, t5, t6, t7, t8, t9), 0)
+                return torch.ops.aten.cat.default(
+                    (t0, t1, t2, t3, t4, t5, t6, t7, t8, t9), 0
+                )

        # Create 10 small tensors
        inputs = [torch.randn(1, 3, device="cuda") for _ in range(10)]
        self.run_test(
            ManyCat(),
@@ -339,16 +347,22 @@

        class CatBF16Constants(nn.Module):
            def __init__(self):
                super().__init__()
                # Register multiple bf16 constant buffers
-                self.register_buffer("bf16_const1", torch.ones(2, 3, dtype=torch.bfloat16))
-                self.register_buffer("bf16_const2", torch.full((2, 3), 2.0, dtype=torch.bfloat16))
+                self.register_buffer(
+                    "bf16_const1", torch.ones(2, 3, dtype=torch.bfloat16)
+                )
+                self.register_buffer(
+                    "bf16_const2", torch.full((2, 3), 2.0, dtype=torch.bfloat16)
+                )

            def forward(self, x):
                # Cat bf16 input with bf16 constants - output should be bf16
-                return torch.ops.aten.cat.default((self.bf16_const1, x, self.bf16_const2), 0)
+                return torch.ops.aten.cat.default(
+                    (self.bf16_const1, x, self.bf16_const2), 0
+                )

        inputs = [torch.randn(2, 3, device="cuda", dtype=torch.bfloat16)]
        self.run_test(
            CatBF16Constants(),
            inputs,

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines (same formatting diff for tests/py/dynamo/conversion/test_cat_aten.py as above).
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines (same formatting diff for tests/py/dynamo/conversion/test_cat_aten.py as above).
@apbose apbose force-pushed the push-toozssvllqvl branch from fd4cfae to d50c951 Compare February 9, 2026 19:41

Labels

cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: core Issues re: The core compiler component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests


Development

Successfully merging this pull request may close these issues.

🐛 [Bug] unify_and_concat_trt_tensors doesn't handle bfloat16s correctly

2 participants