🐛 [Bug] unify_and_concat_trt_tensors doesn't handle bfloat16s correctly #4037

@hayk-24

Bug Description

unify_and_concat_trt_tensors tries to convert a bfloat16 torch.Tensor to a NumPy array, which fails because NumPy has no native bfloat16 dtype.
Possibly related to #4036.
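
The underlying error can be isolated without TensorRT. The following is a minimal sketch, assuming only that `torch` and `numpy` are installed; `numpy_conversion_error` is a hypothetical helper name used for illustration. It shows that `Tensor.numpy()` rejects bfloat16, which is the same `TypeError` the converter runs into.

```python
import torch


def numpy_conversion_error(t: torch.Tensor):
    """Return the error message raised by t.numpy(), or None if it succeeds."""
    try:
        t.numpy()
    except TypeError as err:
        return str(err)
    return None


# float32 converts fine; bfloat16 does not, because NumPy has no such dtype.
ok = numpy_conversion_error(torch.ones(2, 3, dtype=torch.float32))
bad = numpy_conversion_error(torch.ones(2, 3, dtype=torch.bfloat16))
print("float32 :", ok)
print("bfloat16:", bad)
```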

To Reproduce

Steps to reproduce the behavior:

  1. Use the nvidia/pytorch:25.12-py3 container
  2. Run python concat_bfloat_bug.py

concat_bfloat_bug.py
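
The attached script is not reproduced on this page. As a rough guide, a script consistent with the traceback (a constant bfloat16 tensor concatenated with the input along `dim=1`, then compiled with `torch_tensorrt.dynamo.compile`) might look like the sketch below. The class name, tensor shapes, and compile arguments are all assumptions, and this sketch is not confirmed to hit the same code path as the original attachment.

```python
import torch


class PadAndConcat(torch.nn.Module):
    """Hypothetical module: prepends a constant bfloat16 block along dim=1."""

    def __init__(self, pad_len: int = 4, features: int = 8) -> None:
        super().__init__()
        # Constant operand of the cat; its shape is an assumption.
        self.register_buffer(
            "padding", torch.zeros(1, pad_len, features, dtype=torch.bfloat16)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        padding = self.padding
        x = torch.cat([padding, x], dim=1)
        return x


def compile_with_tensorrt(model: torch.nn.Module, example: torch.Tensor):
    """Export and compile with Torch-TensorRT.

    Needs a CUDA GPU and torch_tensorrt, so it is defined but not called here.
    Any extra compile options used by the original script are unknown.
    """
    import torch_tensorrt

    model = model.eval().cuda()
    example = example.cuda()
    exported = torch.export.export(model, (example,))
    return torch_tensorrt.dynamo.compile(exported, inputs=[example])


# Eager check on CPU: the module itself is valid PyTorch.
model = PadAndConcat()
out = model(torch.ones(1, 2, 8, dtype=torch.bfloat16))
print(out.shape, out.dtype)
```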

```
root@2907986418dc:/workspace# python concat_bfloat_bug.py
CUDA 13 is not currently supported for TRT-LLM plugins. Please install pytorch with CUDA 12.x support
/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators,  other operators may also be overridden.
  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor
    registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922
  dispatch key: ADInplaceOrView
  previous kernel: no debug info
       new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)
  self.m.impl(
/usr/local/lib/python3.12/dist-packages/modelopt/torch/utils/import_utils.py:32: UserWarning: Failed to import transformers plugin due to: RuntimeError("Failed to import transformers.trainer because of the following error (look up to see its traceback):\ncannot import name 'amp' from 'apex' (/usr/local/lib/python3.12/dist-packages/apex/__init__.py)"). You may ignore this warning if you do not need this plugin.
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/modelopt/torch/utils/import_utils.py:32: UserWarning: Failed to import huggingface plugin due to: ImportError("cannot import name 'ModelOptHFTrainer' from 'modelopt.torch.opt.plugins' (/usr/local/lib/python3.12/dist-packages/modelopt/torch/opt/plugins/__init__.py)"). You may ignore this warning if you do not need this plugin.
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/modelopt/torch/utils/import_utils.py:32: UserWarning: Failed to import transformers trainer plugin due to: ImportError("cannot import name 'ModelOptHFTrainer' from 'modelopt.torch.opt.plugins' (/usr/local/lib/python3.12/dist-packages/modelopt/torch/opt/plugins/__init__.py)"). You may ignore this warning if you do not need this plugin.
  warnings.warn(
WARNING:py.warnings:/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)

Traceback (most recent call last):
  File "/workspace/concat_bfloat_bug.py", line 40, in <module>
    trt_gm = torch_tensorrt.dynamo.compile(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 716, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 955, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 147, in convert_module
    serialized_interpreter_result = interpret_module_to_result(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 78, in interpret_module_to_result
    interpreter_result = interpreter.run()
                         ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 705, in run
    self._construct_trt_network_def()
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 417, in _construct_trt_network_def
    super().run()
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 774, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
                              ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/interpreter.py", line 256, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 881, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 229, in aten_ops_cat
    return impl.cat.cat(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/impl/cat.py", line 124, in cat
    return unify_and_concat_trt_tensors(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/impl/cat.py", line 49, in unify_and_concat_trt_tensors
    const_arr = np.array([x], dtype=np.int32)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 1254, in __array__
    return self.numpy().astype(dtype, copy=False)
           ^^^^^^^^^^^^
TypeError: Got unsupported ScalarType BFloat16

While executing %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%_frozen_param0, %x], 1), kwargs = {})
Original traceback:
File "/workspace/concat_bfloat_bug.py", line 18, in forward
    x = torch.cat([padding, x], dim=1)
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
```
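
The failing frame is `np.array([x], dtype=np.int32)` with `x` being a bfloat16 tensor. One possible direction for a workaround, sketched under assumptions, is to up-cast bfloat16 tensors to float32 before handing them to NumPy. `to_numpy_constant` is a hypothetical helper name and is not part of Torch-TensorRT; the actual fix in the library may look quite different.

```python
import numpy as np
import torch


def to_numpy_constant(value):
    """Hypothetical helper: convert a constant cat operand to a NumPy array."""
    if isinstance(value, torch.Tensor):
        tensor = value.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            # Lossless up-cast: float32 can represent every bfloat16 value.
            tensor = tensor.to(torch.float32)
        return tensor.numpy()
    if isinstance(value, int):
        # Plain Python ints keep the int32 conversion seen in the traceback.
        return np.array([value], dtype=np.int32)
    raise TypeError(f"unsupported constant type: {type(value).__name__}")


bf16_const = to_numpy_constant(torch.full((1, 4), 0.5, dtype=torch.bfloat16))
int_const = to_numpy_constant(3)
print(bf16_const.dtype, int_const.dtype)
```

Note that the resulting array is float32, so a real fix would still need to cast the TensorRT constant back to bfloat16 (or avoid the NumPy round trip altogether) to keep the network's dtypes consistent.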
## Expected behavior

The attached minimal reproducible example should compile.

I'm using the `nvidia/pytorch:25.12-py3` container with the preinstalled `torch_tensorrt-2.10.0a0`.

> Build information about Torch-TensorRT can be found by turning on debug messages

 - Torch-TensorRT Version (e.g. 1.0.0): torch_tensorrt-2.10.0a0
 - PyTorch Version (e.g. 1.0): torch-2.10.0a0
 - GPU models and configuration: H200
