Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import tensorrt as trt
import torch
from torch.fx.node import Target

from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
Expand Down Expand Up @@ -45,6 +44,8 @@ def convNd(
if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for convolution."

num_dims = len(input.shape) - 2

if is_conv1d:
# Apply an unsqueeze operation to transform the conv1d problem into conv2d
input = impl.unsqueeze.unsqueeze(
Expand Down Expand Up @@ -104,7 +105,12 @@ def convNd(
conv_layer.set_input(2, bias)

# Cast certain fields to tuples, in accordance with TRT requirements
padding = (padding,) if isinstance(padding, int) else padding
if isinstance(padding, int):
padding = (padding,) * num_dims
elif isinstance(padding, (list, tuple)):
padding = tuple(padding)
if len(padding) == 1:
padding = (padding[0],) * num_dims
stride = (stride,) if isinstance(stride, int) else stride
dilation = (dilation,) if isinstance(dilation, int) else dilation

Expand Down
5 changes: 4 additions & 1 deletion tests/py/dynamo/conversion/test_convolution_aten.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from parameterized import param, parameterized
from torch.testing._internal.common_utils import run_tests

from torch_tensorrt import Input

from .harness import DispatchTestCase
Expand Down Expand Up @@ -134,6 +133,8 @@ def forward(self, x):
param("no_bias", 1, bias=False),
("tuple_parameters", 1, (1, 1), (1, 1)),
param("non_zero_padding", 1, padding=1),
param("list_zero_padding", 1, padding=[0]),
param("list_non_padding", 1, padding=[1]),
param("dilation", 1, dilation=2),
param("groups", 1, groups=3),
]
Expand Down Expand Up @@ -205,6 +206,8 @@ def forward(self, x):
param("no_bias", 1, bias=False),
("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)),
param("non_zero_padding", 1, padding=1),
param("list_zero_padding", 1, padding=[0]),
param("list_non_padding", 1, padding=[1]),
param("dilation", 1, dilation=2),
## TODO TRT 8.4.1 will trigger issue with this test. T127981773
# param("groups", 1, groups=3),
Expand Down