PyTorch 2.7, CUDA 12.8, TensorRT 10.9, Python 3.13
Torch-TensorRT 2.7.0 targets PyTorch 2.7, TensorRT 10.9, and CUDA 12.8 (builds for CUDA 11.8/12.4 are available via the PyTorch package index: https://download.pytorch.org/whl/cu118 and https://download.pytorch.org/whl/cu124). Python versions 3.9 through 3.13 are supported. We no longer provide builds for the pre-cxx11 ABI; all wheels and tarballs now use the cxx11 ABI.
Known Issues
Engine refitting is disabled in Python 3.13.
Using Self-Defined Kernels in TensorRT Engines via Automatic Plugin Generation
Users may develop their own custom kernels using DSLs such as OpenAI Triton. Through the use of PyTorch Custom Ops and Torch-TensorRT Automatic Plugin Generation, these kernels can be called within the TensorRT engine with minimal extra code required.
import torch
import torch_tensorrt
import triton
import triton.language as tl


@triton.jit
def elementwise_scale_mul_kernel(X, Y, Z, a, b, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    # Compute the range of elements that this thread block will work on
    block_start = pid * BLOCK_SIZE
    # Range of indices this thread will handle
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Load elements from the X and Y tensors
    x_vals = tl.load(X + offsets)
    y_vals = tl.load(Y + offsets)
    # Perform the element-wise multiplication
    z_vals = x_vals * y_vals * a + b
    # Store the result in Z
    tl.store(Z + offsets, z_vals)


@torch.library.custom_op("torchtrt_ex::elementwise_scale_mul", mutates_args=())  # type: ignore[misc]
def elementwise_scale_mul(
    X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2
) -> torch.Tensor:
    # Ensure the tensors are on the GPU
    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
    assert X.shape == Y.shape, "Tensors must have the same shape."
    # Create output tensor
    Z = torch.empty_like(X)
    # Define block size
    BLOCK_SIZE = 1024
    # Grid of programs
    grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)
    # Launch the kernel with parameters a and b
    elementwise_scale_mul_kernel[grid](X, Y, Z, a, b, BLOCK_SIZE=BLOCK_SIZE)
    return Z


@torch.library.register_fake("torchtrt_ex::elementwise_scale_mul")
def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
    return x


torch_tensorrt.dynamo.conversion.plugins.custom_op(
    "torchtrt_ex::elementwise_scale_mul",
    supports_dynamic_shapes=True,
    requires_output_allocator=False,
)

trt_mod_w_kernel = torch_tensorrt.compile(module, ...)
torch_tensorrt.dynamo.conversion.plugins.custom_op generates a TensorRT plugin via the Quick Deploy Plugin (QDP) system, reusing the information (including PyTorch's FakeTensor mode) that is already required to register a Torch custom op for TorchDynamo. It also generates the Torch-TensorRT converter that inserts the plugin into the TensorRT engine.
QDP plugins for Torch custom ops and converters for QDP plugins can also be generated individually, as in the sketch below.
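For reference, the two pieces might be generated with separate calls roughly as follows. This is a sketch: the generate_plugin and generate_plugin_converter names reflect our understanding of the Automatic Plugin Generation API and are not quoted from the release.

# Sketch: generating the QDP plugin and its converter with individual calls
# (function names below are assumed, not verbatim from the release notes).
torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
    "torchtrt_ex::elementwise_scale_mul"
)
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "torchtrt_ex::elementwise_scale_mul",
    supports_dynamic_shapes=True,
)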
MutableTorchTensorRTModule improvements

MutableTorchTensorRTModule automatically recompiles if the engine becomes invalid. Previously, engines assumed static shapes, which meant that a differently sized input would trigger a recompile or a pull from the engine cache. Developers can now provide shape hints to the MutableTorchTensorRTModule, allowing the module to handle a broader range of inputs without recompiling. For example:
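A minimal sketch of supplying such a hint follows. The set_expected_dynamic_shape_range call and the torch.export.Dim-style hints are our assumptions about the API, and the toy model is hypothetical.

import torch
import torch_tensorrt

# Hypothetical toy model; any nn.Module works here.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).eval().cuda()

mutable_module = torch_tensorrt.MutableTorchTensorRTModule(model)

# Assumed API: hint that the batch dimension of the positional input may range
# from 2 to 8, so one engine can cover that range without recompiling.
batch = torch.export.Dim("batch", min=2, max=8)
args_dynamic_shapes = ({0: batch},)   # one positional tensor input
kwargs_dynamic_shapes = {}            # no keyword tensor inputs
mutable_module.set_expected_dynamic_shape_range(args_dynamic_shapes, kwargs_dynamic_shapes)

out = mutable_module(torch.randn(4, 64, device="cuda"))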
Data Dependent Shape support

For networks that produce outputs whose shapes depend on the shape of the input, the output buffer must be allocated at runtime. To support this use case, we have added a new runtime mode, Dynamic Output Allocation Mode, which handles Data Dependent Shape (DDS) operations such as the NonZero op. (#3388)
Note:
Dynamic output allocation mode cannot be used in conjunction with CUDA Graphs or the pre-allocated outputs feature.
Without dynamic output allocation, the output buffer is allocated based on the output shape inferred from the input size.
There are two scenarios in which dynamic output allocation is enabled:
The model has been identified at compile time to require dynamic output allocation for at least one TensorRT subgraph. These models engage the runtime mode automatically (with logging) and are incompatible with other runtime modes such as CUDA Graphs. Converters can declare that the subgraphs they produce require the output allocator by setting requires_output_allocator=True, thereby forcing any model that uses the converter to automatically run in output allocator mode (a converter sketch follows the snippet below).
Users may manually enable dynamic output allocation mode via the torch_tensorrt.runtime.enable_output_allocator context manager.
# Enables Dynamic Output Allocation Mode, then resets the mode to its prior setting
with torch_tensorrt.runtime.enable_output_allocator(trt_module):
    ...
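For the first scenario, a converter registration might declare the flag roughly as in the following sketch. The decorator import path and the nonzero target are illustrative assumptions rather than an excerpt from the codebase.

# Hypothetical sketch: a converter for a data-dependent-shape op declaring
# that it needs the output allocator (import path and target are assumed).
import torch
from torch_tensorrt.dynamo.conversion import dynamo_tensorrt_converter

@dynamo_tensorrt_converter(
    torch.ops.aten.nonzero.default,
    supports_dynamic_shapes=True,
    requires_output_allocator=True,  # any model using this converter runs in output allocator mode
)
def aten_ops_nonzero(ctx, target, args, kwargs, name):
    ...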
Tiling Optimization support
Tiling optimization enables cross-kernel tiled inference. This technique leverages on-chip caching for continuous kernels in addition to kernel-level tiling. It can significantly enhance performance on platforms constrained by memory bandwidth. (#3444)
We currently support four tiling strategies: "none", "fast", "moderate", and "full". A higher level allows TensorRT to spend more time searching for a better tiling strategy. Here's an example of enabling tiling optimization:
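A compile call might look like the sketch below; the tiling_optimization_level argument name is our assumption based on the feature description, and MyModel is a placeholder.

import torch
import torch_tensorrt

# Sketch only: `tiling_optimization_level` is an assumed setting name.
model = MyModel().eval().cuda()   # hypothetical model
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)],
    tiling_optimization_level="moderate",  # "none" | "fast" | "moderate" | "full"
)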
Model Zoo additions

Added support for compiling the FLUX.1-dev 12B model in our model zoo. An example is available here. Quantized variants of FLUX are under development as part of future work.
General Improvements
Improved BF16 support in model compilation by fixing bugs and adding new tests to cover both full-graph and graph-break scenarios.
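As a generic illustration (not taken from the new tests), BF16 compilation is typically requested through the enabled_precisions setting; MyModel is a placeholder and the exact settings are assumptions.

import torch
import torch_tensorrt

# Generic sketch of BF16 compilation (assumed settings, hypothetical model).
model = MyModel().eval().cuda().to(torch.bfloat16)
trt_bf16 = torch_tensorrt.compile(
    model,
    inputs=[torch.randn(1, 3, 224, 224, device="cuda", dtype=torch.bfloat16)],
    enabled_precisions={torch.bfloat16},
)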
Python 3.13 support

We added support for Python 3.13 (#3455). However, due to a Python object reference issue in PyTorch 2.7, we have disabled the refit-related features for Python 3.13 in this release. This issue should be fixed in the next release.
What's Changed
feat: add --use_python_runtime and --enable_cuda_graph args to the perf run script by @zewenli98 in #3397

New Contributors
Full Changelog: v2.6.0...v2.7.0