🐛 [Bug] BF16 causing "Unspported numpy dtype" error in create_constant #2902

@HolyWu

Bug Description

WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
INFO:torch_tensorrt.dynamo.utils:Using Default Torch-TRT Runtime (as requested by user)
INFO:torch_tensorrt.dynamo.utils:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.bf16: 10>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False)
DEBUG:torch_tensorrt.dynamo.backend.backends:Pre-AOT Autograd graph:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %l__self___linear : [num_users=1] = call_module[target=L__self___linear](args = (%l_x_,), kwargs = {})
    return (l__self___linear,)
DEBUG:torch_tensorrt.dynamo.lowering._repair_input_aliasing:Inserted auxiliary clone nodes for placeholders:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %l__self___linear : [num_users=1] = call_module[target=L__self___linear](args = (%clone_default,), kwargs = {})
    return (l__self___linear,)
DEBUG:torch_tensorrt.dynamo.lowering._remove_sym_nodes:Removed SymInt placeholders:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %l__self___linear : [num_users=1] = call_module[target=L__self___linear](args = (%clone_default,), kwargs = {})
    return (l__self___linear,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %l__self___linear : [num_users=1] = call_module[target=L__self___linear](args = (%clone_default,), kwargs = {})
    return (l__self___linear,)
DEBUG:torch_tensorrt.dynamo.backend.backends:Post-AOT Autograd graph:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
    %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%_param_constant0, [1, 0]), kwargs = {})
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %clone, %permute), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removing node clone from graph, since it is a clone node which is the only user of placeholder arg0_1 and was inserted by the compiler.
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removed auxiliary clone nodes for placeholders:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%_param_constant0, [1, 0]), kwargs = {})
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %permute), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %_frozen_param0), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.backend.backends:Lowered Input graph:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %_frozen_param0), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.addmm.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
WARNING:torch_tensorrt.dynamo._compiler:Node _param_constant1 of op type get_attr does not have metadata. This could sometimes lead to undefined behavior.
WARNING:torch_tensorrt.dynamo._compiler:Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments.
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.addmm.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Submodule name: _run_on_acc_0
Input shapes: [(128, 20)]
graph():
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %_frozen_param0), kwargs = {})
    return addmm
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +3, GPU +0, now: CPU 12984, GPU 1045 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +2657, GPU +308, now: CPU 15907, GPU 1353 (MiB)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: arg0_1 [shape=[128, 20], dtype=DataType.BF16]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node addmm (kind: aten.addmm.default, args: ('<torch.Tensor as np.ndarray [shape=(30,), dtype=float32]>', 'arg0_1 <tensorrt.ITensor [shape=(128, 20), dtype=DataType.BF16]>', '<torch.Tensor as np.ndarray [shape=(20, 30), dtype=float32]>'))
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Freezing tensor addmm_constant_0 to TRT IConstantLayer
Traceback (most recent call last):
  File "C:\Users\HolyWu\Downloads\test.py", line 29, in <module>
    optimized_model(*inputs)
  File "C:\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1552, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1561, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\eval_frame.py", line 432, in _fn
    return fn(*args, **kwargs)
  File "C:\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1552, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1561, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 1115, in __call__
    return self._torchdynamo_orig_callable(
  File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 947, in __call__
    result = self._inner_convert(
  File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 471, in __call__
    return _compile(
  File "C:\Python312\Lib\site-packages\torch\_utils_internal.py", line 83, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "C:\Python312\Lib\site-packages\torch\_strobelight\compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "C:\Python312\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 816, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\utils.py", line 232, in time_wrapper
    r = func(*args, **kwargs)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 635, in compile_inner
    out_code = transform_code_object(code, transform)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\bytecode_transformation.py", line 1184, in transform_code_object
    transformations(instructions, code_options)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 177, in _fn
    return fn(*args, **kwargs)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 581, in transform
    tracer.run()
  File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2455, in run
    super().run()
  File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 897, in run
    while self.step():
  File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 809, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2646, in RETURN_VALUE
    self._return(inst)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2631, in _return
    self.output.compile_subgraph(
  File "C:\Python312\Lib\site-packages\torch\_dynamo\output_graph.py", line 1097, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "C:\Python312\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\output_graph.py", line 1314, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\utils.py", line 232, in time_wrapper
    r = func(*args, **kwargs)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\output_graph.py", line 1405, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "C:\Python312\Lib\site-packages\torch\_dynamo\output_graph.py", line 1386, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "C:\Python312\Lib\site-packages\torch\_dynamo\repro\after_dynamo.py", line 128, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "C:\Python312\Lib\site-packages\torch\__init__.py", line 1989, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\backend\backends.py", line 44, in torch_tensorrt_backend
    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\backend\backends.py", line 52, in aot_torch_tensorrt_aten_backend
    return _pretraced_backend(gm, sample_inputs, settings)
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\backend\backends.py", line 108, in _pretraced_backend
    trt_compiled = compile_module(
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\_compiler.py", line 412, in compile_module
    trt_module = convert_module(
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_conversion.py", line 106, in convert_module
    interpreter_result = interpret_module_to_result(module, inputs, settings)
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_conversion.py", line 87, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_TRTInterpreter.py", line 310, in run
    super().run()
  File "C:\Python312\Lib\site-packages\torch\fx\interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_TRTInterpreter.py", line 349, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File "C:\Python312\Lib\site-packages\torch\fx\interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_TRTInterpreter.py", line 457, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\converter_utils.py", line 469, in convert_with_type_enforcement
    return func(ctx, target, new_args, new_kwargs, name)
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\aten_ops_converters.py", line 2714, in aten_ops_addmm
    return impl.addmm.addmm(
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\impl\addmm.py", line 24, in addmm
    mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2)
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\impl\matmul.py", line 28, in matrix_multiply
    other = get_trt_tensor(
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\converter_utils.py", line 328, in get_trt_tensor
    return create_constant(ctx, input_val, name, dtype, min_rank)
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\converter_utils.py", line 287, in create_constant
    value, _enums.dtype._from(dtype).to(np.dtype) if dtype is not None else None
  File "C:\Python312\Lib\site-packages\torch_tensorrt\_enums.py", line 279, in to
    raise TypeError("Unspported numpy dtype")
torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt_backend' raised:
TypeError: Unspported numpy dtype

While executing %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %_frozen_param0), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x000001E4A3D7F930>: ((128, 20), torch.bfloat16, False, (20, 1), torch.contiguous_format, False, {})}})

Original traceback:
None

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

To Reproduce

import torch
import torch_tensorrt


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(20, 30)

    def forward(self, x):
        return self.linear(x)


device = torch.device("cuda", 0)
model = MyModule().eval().to(device).bfloat16()
inputs = [torch.randn((128, 20), dtype=torch.bfloat16, device=device)]

with torch.inference_mode():
    optimized_model = torch_tensorrt.compile(
        model,
        ir="torch_compile",
        inputs=inputs,
        enabled_precisions={torch.bfloat16},
        debug=True,
        min_block_size=1,
        device=device,
    )
    optimized_model(*inputs)
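Incidentally, the failing call chain from the traceback (get_trt_tensor -> create_constant -> _enums.dtype.to) can be triggered in isolation, without building an engine. A minimal sketch, assuming the same torch_tensorrt build as above:

    import numpy as np
    import torch
    from torch_tensorrt import _enums

    # Same expression as in converter_utils.create_constant (line 287 in the
    # traceback); raises TypeError("Unspported numpy dtype") for bfloat16.
    _enums.dtype._from(torch.bfloat16).to(np.dtype)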

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): 2.4.0.dev20240607+cu124
  • PyTorch Version (e.g. 1.0): 2.4.0.dev20240607+cu124
  • CPU Architecture: x64
  • OS (e.g., Linux): Windows 11
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.12.3
  • CUDA version: 12.4
  • GPU models and configuration: RTX 3050
  • Any other relevant information:

Additional context

Adding the use_default=True argument to the to(np.dtype) call at

value, _enums.dtype._from(dtype).to(np.dtype) if dtype is not None else None

makes the compilation succeed, but I'm not sure whether you'd rather fix it some other way.
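For clarity, the change I'm suggesting would look roughly like this (a sketch only; I'm assuming use_default=True makes dtype.to fall back to a supported NumPy dtype, presumably np.float32, when no exact bf16 equivalent exists):

    import numpy as np
    import torch
    from torch_tensorrt import _enums

    # Without the fallback this raises for bf16:
    #     _enums.dtype._from(torch.bfloat16).to(np.dtype)
    # With the fallback it returns a supported dtype instead, so create_constant
    # could read:
    #     value, _enums.dtype._from(dtype).to(np.dtype, use_default=True) if dtype is not None else None
    print(_enums.dtype._from(torch.bfloat16).to(np.dtype, use_default=True))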
