fix: Refactor the cat converter and separate out the mixed use #4059
base: main
Conversation
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_cat_aten.py 2026-01-29 01:57:26.838331+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_cat_aten.py 2026-01-29 01:58:08.317313+00:00
@@ -141,11 +141,13 @@
"""Test cat with mixed float32 and float16 tensors - should promote to float32"""
class MixedDtypeCat(nn.Module):
def __init__(self):
super().__init__()
- self.register_buffer("const_fp16", torch.ones(2, 3, dtype=torch.float16))
+ self.register_buffer(
+ "const_fp16", torch.ones(2, 3, dtype=torch.float16)
+ )
def forward(self, x):
# x is float32, const_fp16 is float16, result should be float32
return torch.ops.aten.cat.default((self.const_fp16, x), 0)
@@ -177,12 +179,16 @@
"""Test cat with three different dtypes - bfloat16, float16, float32"""
class ThreeDtypeCat(nn.Module):
def __init__(self):
super().__init__()
- self.register_buffer("const_bf16", torch.ones(2, 3, dtype=torch.bfloat16))
- self.register_buffer("const_fp16", torch.ones(2, 3, dtype=torch.float16))
+ self.register_buffer(
+ "const_bf16", torch.ones(2, 3, dtype=torch.bfloat16)
+ )
+ self.register_buffer(
+ "const_fp16", torch.ones(2, 3, dtype=torch.float16)
+ )
def forward(self, x):
# bf16, fp16, fp32 -> should promote to fp32
return torch.ops.aten.cat.default(
(self.const_bf16, self.const_fp16, x), 0
@@ -197,11 +203,13 @@
def test_cat_many_tensors(self):
"""Test cat with many tensors (10+)"""
class ManyCat(nn.Module):
def forward(self, t0, t1, t2, t3, t4, t5, t6, t7, t8, t9):
- return torch.ops.aten.cat.default((t0, t1, t2, t3, t4, t5, t6, t7, t8, t9), 0)
+ return torch.ops.aten.cat.default(
+ (t0, t1, t2, t3, t4, t5, t6, t7, t8, t9), 0
+ )
# Create 10 small tensors
inputs = [torch.randn(1, 3, device="cuda") for _ in range(10)]
self.run_test(
ManyCat(),
@@ -339,16 +347,22 @@
class CatBF16Constants(nn.Module):
def __init__(self):
super().__init__()
# Register multiple bf16 constant buffers
- self.register_buffer("bf16_const1", torch.ones(2, 3, dtype=torch.bfloat16))
- self.register_buffer("bf16_const2", torch.full((2, 3), 2.0, dtype=torch.bfloat16))
+ self.register_buffer(
+ "bf16_const1", torch.ones(2, 3, dtype=torch.bfloat16)
+ )
+ self.register_buffer(
+ "bf16_const2", torch.full((2, 3), 2.0, dtype=torch.bfloat16)
+ )
def forward(self, x):
# Cat bf16 input with bf16 constants - output should be bf16
- return torch.ops.aten.cat.default((self.bf16_const1, x, self.bf16_const2), 0)
+ return torch.ops.aten.cat.default(
+ (self.bf16_const1, x, self.bf16_const2), 0
+ )
inputs = [torch.randn(2, 3, device="cuda", dtype=torch.bfloat16)]
self.run_test(
CatBF16Constants(),
inputs,
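For reference, the dtype-promotion behavior these reformatted tests exercise can be reproduced in eager PyTorch, independent of the conversion test harness. The following is a minimal sketch only (tensor names and shapes chosen for illustration, not taken from the PR):

# Minimal eager-mode check of the promotion rule the new tests rely on:
# cat on mixed float16/float32 inputs yields float32.
import torch

x = torch.randn(2, 3, dtype=torch.float32)
const_fp16 = torch.ones(2, 3, dtype=torch.float16)

out = torch.ops.aten.cat.default((const_fp16, x), 0)
assert out.dtype == torch.float32
assert out.shape == (4, 3)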
Force-pushed from a43daa2 to fee1de9
Force-pushed from fee1de9 to fd4cfae
Force-pushed from fd4cfae to d50c951
Description
Reported issue #4037 highlighted that the cat converter was falling into a code path that, AFAICT, is intended to manufacture Shape Tensors for dynamic-shape upsampling. This PR refactors the cat converter to separate these two use cases so that it is easier to understand what is happening, and it also adds a number of test cases.
Fixes #4037
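For readers unfamiliar with the converter, the shape of the split can be sketched roughly as below. This is an illustrative outline only, not the actual Torch-TensorRT converter (which operates on TensorRT ITensors rather than torch tensors); the helper names looks_like_shape_cat, cat_shape_values, and cat_data_tensors are hypothetical.

# Illustrative sketch of separating the two cat use cases -- NOT the actual
# Torch-TensorRT converter code. Helper names are hypothetical.
from typing import Sequence, Union

import torch


def looks_like_shape_cat(tensors: Sequence[object]) -> bool:
    # Hypothetical check: treat the cat as shape-tensor assembly when every
    # input is a Python int or a small sequence of ints (e.g. dimension
    # values gathered for dynamic-shape upsampling).
    return all(isinstance(t, (int, list, tuple)) for t in tensors)


def cat_shape_values(tensors: Sequence[Union[int, Sequence[int]]]) -> list:
    # Flatten ints and int sequences into a single shape vector.
    out: list = []
    for t in tensors:
        if isinstance(t, (list, tuple)):
            out.extend(t)
        else:
            out.append(t)
    return out


def cat_data_tensors(tensors: Sequence[torch.Tensor], dim: int) -> torch.Tensor:
    # Regular concatenation; torch.cat applies dtype promotion for mixed
    # float16/bfloat16/float32 inputs.
    return torch.cat(list(tensors), dim=dim)


def cat_converter(tensors, dim: int = 0):
    # Route the two use cases through separate helpers instead of one
    # intertwined code path.
    if looks_like_shape_cat(tensors):
        return cat_shape_values(tensors)
    return cat_data_tensors(tensors, dim)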
Type of change
Checklist: