🐛 [Bug] unify_and_concat_trt_tensors doesn't handle non-flat concats correctly #4036

@hayk-24

Bug Description

unify_and_concat_trt_tensors incorrectly flattens the left-hand operand of the concat operation, which causes errors downstream.
Possibly related to #4035.
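To make the failure mode concrete, here is a minimal pure-Python sketch of the rank check that the TensorRT error in the log below describes ("all concat input tensors must have the same number of dimensions"). It does not call Torch-TensorRT; `concat_output_shape` and `flatten` are hypothetical helpers written only for illustration. The pre-flattening LHS shape of [3, 100] and the concat dimension 0 are assumptions, chosen to be consistent with the reported shapes [300] and [3, 100].

```python
from typing import List, Sequence


def concat_output_shape(shapes: Sequence[Sequence[int]], dim: int) -> List[int]:
    """Shape of concatenating tensors with the given shapes along `dim`.

    Hypothetical helper: it mirrors the rule quoted in the error log, i.e.
    every input must have the same number of dimensions, and all dimensions
    other than the concat dimension must match.
    """
    rank = len(shapes[0])
    for i, shape in enumerate(shapes):
        if len(shape) != rank:
            raise ValueError(
                f"rank mismatch: input 0 shape {list(shapes[0])}, "
                f"input {i} shape {list(shape)}"
            )
    axis = dim % rank  # allow negative dims, as torch.cat does
    for i, shape in enumerate(shapes):
        for d in range(rank):
            if d != axis and shape[d] != shapes[0][d]:
                raise ValueError(f"extent mismatch at dim {d} for input {i}")
    out = list(shapes[0])
    out[axis] = sum(shape[axis] for shape in shapes)
    return out


def flatten(shape: Sequence[int]) -> List[int]:
    """Shape of a tensor after collapsing every dimension into one."""
    count = 1
    for extent in shape:
        count *= extent
    return [count]


# Assumed operand shapes (the original script is not reproduced here).
lhs, rhs = [3, 100], [3, 100]

# Without flattening, the concat is well formed.
print(concat_output_shape([lhs, rhs], dim=0))  # [6, 100]

# Flattening only the LHS reproduces the reported mismatch: [300] vs [3, 100].
try:
    concat_output_shape([flatten(lhs), rhs], dim=0)
except ValueError as err:
    print(err)
```

If the helper preserved the LHS rank, both operands would reach the concatenation layer with two dimensions and the first call above is what the converter would be expected to produce.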

To Reproduce

Steps to reproduce the behavior:

  1. Use the nvidia/pytorch:25.12-py3 container.
  2. Run python concat_bug.py.
  3. Compilation fails with a shape mismatch because unify_and_concat_trt_tensors flattens the LHS of the concat:
root@2907986418dc:/workspace# python concat_bug.py
CUDA 13 is not currently supported for TRT-LLM plugins. Please install pytorch with CUDA 12.x support
/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators,  other operators may also be overridden.
  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor
    registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922
  dispatch key: ADInplaceOrView
  previous kernel: no debug info
       new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)
  self.m.impl(
/usr/local/lib/python3.12/dist-packages/modelopt/torch/utils/import_utils.py:32: UserWarning: Failed to import transformers plugin due to: RuntimeError("Failed to import transformers.trainer because of the following error (look up to see its traceback):\ncannot import name 'amp' from 'apex' (/usr/local/lib/python3.12/dist-packages/apex/__init__.py)"). You may ignore this warning if you do not need this plugin.
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/modelopt/torch/utils/import_utils.py:32: UserWarning: Failed to import huggingface plugin due to: ImportError("cannot import name 'ModelOptHFTrainer' from 'modelopt.torch.opt.plugins' (/usr/local/lib/python3.12/dist-packages/modelopt/torch/opt/plugins/__init__.py)"). You may ignore this warning if you do not need this plugin.
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/modelopt/torch/utils/import_utils.py:32: UserWarning: Failed to import transformers trainer plugin due to: ImportError("cannot import name 'ModelOptHFTrainer' from 'modelopt.torch.opt.plugins' (/usr/local/lib/python3.12/dist-packages/modelopt/torch/opt/plugins/__init__.py)"). You may ignore this warning if you do not need this plugin.
  warnings.warn(
WARNING:py.warnings:/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)

ERROR:torch_tensorrt [TensorRT Conversion Context]:ITensor::getDimensions: Error Code 4: API Usage Error (Error while computing output extent of node [CONCATENATION]-[unknown_ir_ops.cat.default]-[/cat_concat]. In dispatchComputeOutputExtents at /_src/optimizer/shapeof/graphShapeAnalyzer.cpp:1715)
ERROR:torch_tensorrt [TensorRT Conversion Context]:ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [CONCATENATION]-[unknown_ir_ops.cat.default]-[/cat_concat]. In needTypeAndDimensions at /_src/optimizer/shapeof/graphShapeAnalyzer.cpp:2994)
ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildEngineWithConfig: Error Code 4: Internal Error ([CONCATENATION]-[unknown_ir_ops.cat.default]-[/cat_concat]: all concat input tensors must have the same number of dimensions. Input 0 shape: [300]. Input 1 shape: [3,100]. In estimateOutputDims at /_src/optimizer/api/layers/concatenationLayer.cpp:77)
Traceback (most recent call last):
  File "/workspace/concat_bug.py", line 40, in <module>
    trt_gm = torch_tensorrt.dynamo.compile(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 716, in compile
    trt_gm = compile_module(

Expected behavior

The attached minimal reproducible example should compile.

concat_bug.py

I'm using the nvidia/pytorch:25.12-py3 container with the preinstalled torch_tensorrt-2.10.0a0.

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): torch_tensorrt-2.10.0a0
  • PyTorch Version (e.g. 1.0): torch-2.10.0a0
  • GPU models and configuration: H200

Labels

bug, story: Dynamo Frontend & Partitioning
