Skip to content

Commit 6ccc712

Browse files
authored
Remove addressed workaround in ResizeV2 (#7606)
1 parent 508bc1d commit 6ccc712

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,14 @@ def resize_image_tensor(
190190
if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
191191
# uint8 dtype can be included for cpu and cuda input if nearest mode
192192
acceptable_dtypes.append(torch.uint8)
193-
elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
193+
elif (
194+
interpolation == InterpolationMode.BILINEAR
195+
and image.device.type == "cpu"
196+
and "AVX2" in torch.backends.cpu.get_cpu_capability()
197+
):
194198
# uint8 dtype support for bilinear mode is limited to cpu and
195199
# according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
196-
if "AVX2" in torch.backends.cpu.get_cpu_capability():
197-
acceptable_dtypes.append(torch.uint8)
198-
199-
# TODO: Remove when https://github.com/pytorch/pytorch/pull/101136 is landed
200-
if dtype == torch.uint8 and not (
201-
image.is_contiguous() or image.is_contiguous(memory_format=torch.channels_last)
202-
):
203-
image = image.contiguous(memory_format=torch.channels_last)
200+
acceptable_dtypes.append(torch.uint8)
204201

205202
strides = image.stride()
206203
if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:

0 commit comments

Comments
 (0)