root@2907986418dc:/workspace# python concat_bfloat_bug.py
CUDA 13 is not currently supported for TRT-LLM plugins. Please install pytorch with CUDA 12.x support
/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden.
Overriding a previously registered kernel for the same operator and the same dispatch key
operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor
registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922
dispatch key: ADInplaceOrView
previous kernel: no debug info
new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)
self.m.impl(
/usr/local/lib/python3.12/dist-packages/modelopt/torch/utils/import_utils.py:32: UserWarning: Failed to import transformers plugin due to: RuntimeError("Failed to import transformers.trainer because of the following error (look up to see its traceback):\ncannot import name 'amp' from 'apex' (/usr/local/lib/python3.12/dist-packages/apex/__init__.py)"). You may ignore this warning if you do not need this plugin.
warnings.warn(
/usr/local/lib/python3.12/dist-packages/modelopt/torch/utils/import_utils.py:32: UserWarning: Failed to import huggingface plugin due to: ImportError("cannot import name 'ModelOptHFTrainer' from 'modelopt.torch.opt.plugins' (/usr/local/lib/python3.12/dist-packages/modelopt/torch/opt/plugins/__init__.py)"). You may ignore this warning if you do not need this plugin.
warnings.warn(
/usr/local/lib/python3.12/dist-packages/modelopt/torch/utils/import_utils.py:32: UserWarning: Failed to import transformers trainer plugin due to: ImportError("cannot import name 'ModelOptHFTrainer' from 'modelopt.torch.opt.plugins' (/usr/local/lib/python3.12/dist-packages/modelopt/torch/opt/plugins/__init__.py)"). You may ignore this warning if you do not need this plugin.
warnings.warn(
WARNING:py.warnings:/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
Traceback (most recent call last):
File "/workspace/concat_bfloat_bug.py", line 40, in <module>
trt_gm = torch_tensorrt.dynamo.compile(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 716, in compile
trt_gm = compile_module(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/_compiler.py", line 955, in compile_module
trt_module = convert_module(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 147, in convert_module
serialized_interpreter_result = interpret_module_to_result(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 78, in interpret_module_to_result
interpreter_result = interpreter.run()
^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 705, in run
self._construct_trt_network_def()
File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 417, in _construct_trt_network_def
super().run()
File "/usr/local/lib/python3.12/dist-packages/torch/fx/interpreter.py", line 174, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 774, in run_node
trt_node: torch.fx.Node = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/interpreter.py", line 256, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 881, in call_function
return converter(self.ctx, target, args, kwargs, self._cur_node_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 229, in aten_ops_cat
return impl.cat.cat(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/impl/cat.py", line 124, in cat
return unify_and_concat_trt_tensors(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch_tensorrt/dynamo/conversion/impl/cat.py", line 49, in unify_and_concat_trt_tensors
const_arr = np.array([x], dtype=np.int32)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 1254, in __array__
return self.numpy().astype(dtype, copy=False)
^^^^^^^^^^^^
TypeError: Got unsupported ScalarType BFloat16
While executing %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%_frozen_param0, %x], 1), kwargs = {})
Original traceback:
File "/workspace/concat_bfloat_bug.py", line 18, in forward
x = torch.cat([padding, x], dim=1)
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
## Expected behavior
The attached minimal reproducible example should compile.
## Environment

I'm using the `nvidia/pytorch:25.12-py3` container with the pre-installed `torch_tensorrt-2.10.0a0`.
> Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): torch_tensorrt-2.10.0a0
- PyTorch Version (e.g. 1.0): torch-2.10.0a0
- GPU models and configuration: H200
## Bug Description

`unify_and_concat_trt_tensors` tries to convert a bfloat16 tensor to a NumPy array via `np.array([x], dtype=np.int32)`. NumPy has no bfloat16 dtype, so the implicit tensor-to-ndarray conversion raises `TypeError: Got unsupported ScalarType BFloat16`.

Maybe connected to #4036.
## To Reproduce

Steps to reproduce the behavior: run the attached `concat_bfloat_bug.py`.