Description
Bug Description
Passing a boolean value inside a dict to the kwarg_inputs parameter of the torch_tensorrt.compile method results in:
ValueError: Invalid input type <class 'bool'> encountered in the dynamo_compile input parsing. Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}
It seems that, apart from the collection types (list, tuple, dict), only torch.Tensor (or torch_tensorrt.Input) values are allowed at the leaf level. This contradicts the documentation at https://pytorch.org/TensorRT/py_api/torch_tensorrt.html?highlight=compile, which declares the parameter as follows, suggesting that values of any type are accepted:
kwarg_inputs: Optional[dict[Any, Any]] = None
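The reported behavior can be modeled as a recursive walk that descends into list, tuple and dict and rejects every other leaf that is not a tensor. This is a hypothetical, simplified sketch inferred only from the error message, not Torch-TensorRT's actual prepare_inputs code; `parse_inputs_strict` and `TensorLike` (a placeholder for torch.Tensor / torch_tensorrt.Input so the sketch runs without PyTorch) are made-up names.

```python
from typing import Any


class TensorLike:
    """Hypothetical placeholder for torch.Tensor / torch_tensorrt.Input."""


def parse_inputs_strict(value: Any) -> Any:
    """Recurse into list/tuple/dict and reject any other non-tensor leaf."""
    if isinstance(value, TensorLike):
        return value
    if isinstance(value, (list, tuple)):
        return type(value)(parse_inputs_strict(v) for v in value)
    if isinstance(value, dict):
        return {k: parse_inputs_strict(v) for k, v in value.items()}
    # Any other leaf (bool, int, None, ...) ends up here.
    raise ValueError(
        f"Invalid input type {type(value)} encountered in input parsing. "
        "Allowed input types: {TensorLike, list, tuple, dict}"
    )


parse_inputs_strict({"param1": TensorLike()})  # accepted
try:
    parse_inputs_strict({"additional_param": True})
except ValueError as err:
    print(err)  # prints: Invalid input type <class 'bool'> encountered in input parsing. ...
```

Under this model the dict itself is fine; it is the bool leaf inside it that triggers the error, which matches the traceback's recursive prepare_inputs calls.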
To Reproduce
Steps to reproduce the behavior:
- Execute the following minimal example:
import torch
import torch_tensorrt

class TestModel(torch.nn.Module):
    def forward(self, param1, additional_param: bool | None = None):
        return param1

compiled_model = torch_tensorrt.compile(
    TestModel(),
    ir="dynamo",
    inputs=[torch.rand(1)],
    kwarg_inputs={"additional_param": True},  # non-tensor (bool) keyword input
)
- The result is:
Traceback (most recent call last):
  File "...\test_bug.py", line 8, in <module>
    compiled_model = torch_tensorrt.compile(
  File "...\lib\site-packages\torch_tensorrt\_compile.py", line 284, in compile
    torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
  File "...\lib\site-packages\torch_tensorrt\dynamo\utils.py", line 272, in prepare_inputs
    torchtrt_input = prepare_inputs(
  File "...\lib\site-packages\torch_tensorrt\dynamo\utils.py", line 280, in prepare_inputs
    raise ValueError(
ValueError: Invalid input type <class 'bool'> encountered in the dynamo_compile input parsing. Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}
Expected behavior
The minimal example should compile without errors. In my opinion, values other than torch tensors should be accepted in both inputs and kwarg_inputs. It would also help if the documentation described in more detail how the compiler treats inputs and what happens to them when the compiled model is run, which I consider an important topic.
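One possible shape of the requested behavior is sketched below: walk the nested inputs without rejecting any leaf type, and separate tensor leaves (which would become runtime inputs of the compiled graph) from other values such as the boolean flag (which a compiler could, for example, treat as compile-time constants). This is an assumption about a possible design, not Torch-TensorRT's behavior or API; `split_inputs` and the placeholder `TensorLike` class are hypothetical names used so the sketch runs without PyTorch.

```python
from typing import Any


class TensorLike:
    """Hypothetical placeholder for torch.Tensor / torch_tensorrt.Input."""


def split_inputs(value: Any, path: str = "") -> tuple[list[str], dict[str, Any]]:
    """Walk nested list/tuple/dict inputs without rejecting any leaf type.

    Returns the paths of tensor leaves (runtime inputs) and a mapping from
    path to every non-tensor leaf (candidates for compile-time constants).
    """
    if isinstance(value, dict):
        items = [(f"{path}[{k!r}]", v) for k, v in value.items()]
    elif isinstance(value, (list, tuple)):
        items = [(f"{path}[{i}]", v) for i, v in enumerate(value)]
    elif isinstance(value, TensorLike):
        return [path], {}
    else:
        # bool, int, None, ... are kept instead of raising ValueError.
        return [], {path: value}

    tensors: list[str] = []
    constants: dict[str, Any] = {}
    for child_path, child in items:
        child_tensors, child_constants = split_inputs(child, child_path)
        tensors.extend(child_tensors)
        constants.update(child_constants)
    return tensors, constants


tensor_paths, constants = split_inputs({"x": TensorLike(), "additional_param": True})
print(tensor_paths)  # ["['x']"]
print(constants)     # {"['additional_param']": True}
```

Keeping the path of each leaf is one way to let a compiled model check at runtime that a constant it was specialized on still has the same value; whether and how Torch-TensorRT would do this is not stated in this report.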
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
Sorry, I do not know a canonical way of "turning on debug messages" in Python, so I am not sure how to act on this.
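For the question above, a minimal sketch using Python's standard logging module, assuming that the Dynamo frontend of Torch-TensorRT logs through loggers under the "torch_tensorrt" namespace (e.g. "torch_tensorrt.dynamo..."). The Torch-TensorRT documentation also describes a context manager, torch_tensorrt.logging.debug(), for wrapping a compile call; check the docs for the installed version before relying on it.

```python
import logging

# Root handler at INFO so other libraries stay quiet; format shows the logger name.
logging.basicConfig(level=logging.INFO, format="%(name)s %(levelname)s: %(message)s")

# Assumption: Torch-TensorRT's Python loggers are children of "torch_tensorrt",
# so raising this parent to DEBUG makes their debug records reach the handler.
trt_logger = logging.getLogger("torch_tensorrt")
trt_logger.setLevel(logging.DEBUG)

# Demonstration with a child logger name; real messages would come from
# torch_tensorrt.compile(...) called after this setup.
logging.getLogger("torch_tensorrt.dynamo").debug("example debug message")
```

This only needs to run once at the top of the script, before torch_tensorrt.compile is called.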
- Torch-TensorRT Version (e.g. 1.0.0) / PyTorch Version (e.g. 1.0):
  tensorrt==10.7.0
  tensorrt_cu12==10.7.0
  tensorrt_cu12_bindings==10.7.0
  tensorrt_cu12_libs==10.7.0
  torch==2.6.0+cu124
  torch_tensorrt==2.6.0+cu124
- CPU Architecture: Intel x86_64
- OS (e.g., Linux): Windows 10
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Python version: Python 3.10.16
- CUDA version: 12.4
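The package versions listed above can also be collected programmatically. A minimal sketch using only the standard library (importlib.metadata and platform); the package list mirrors the one in this report. PyTorch additionally ships `python -m torch.utils.collect_env`, which prints a fuller report including the CUDA version.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the environment list in this report.
PACKAGES = [
    "tensorrt",
    "tensorrt_cu12",
    "tensorrt_cu12_bindings",
    "tensorrt_cu12_libs",
    "torch",
    "torch_tensorrt",
]


def installed_versions(names: list[str]) -> dict[str, str]:
    """Map each distribution name to its installed version, or 'not installed'."""
    result: dict[str, str] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = "not installed"
    return result


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name}=={ver}")
print("CPU Architecture:", platform.machine())
print("OS:", platform.system(), platform.release())
print("Python version:", platform.python_version())
```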