
Commit d8cec34

datumbox and pmeier authored
[prototype] Clean up and port the resize kernel in V2 (#6892)
* Ported `resize`
* Align with previous behaviour
* Update torchvision/prototype/transforms/functional/_geometry.py
* Moving input verification on top of method.

Co-authored-by: Philip Meier <github.pmeier@posteo.de>
1 parent e64784c commit d8cec34

File tree

1 file changed: +24 -4 lines changed


torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 24 additions & 4 deletions
@@ -4,6 +4,7 @@
 
 import PIL.Image
 import torch
+from torch.nn.functional import interpolate
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 from torchvision.transforms.functional import (
@@ -115,20 +116,37 @@ def resize_image_tensor(
     max_size: Optional[int] = None,
     antialias: bool = False,
 ) -> torch.Tensor:
+    align_corners: Optional[bool] = None
+    if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC:
+        align_corners = False
+    elif antialias:
+        raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")
+
     shape = image.shape
     num_channels, old_height, old_width = shape[-3:]
     new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
 
     if image.numel() > 0:
         image = image.reshape(-1, num_channels, old_height, old_width)
 
-        image = _FT.resize(
+        dtype = image.dtype
+        need_cast = dtype not in (torch.float32, torch.float64)
+        if need_cast:
+            image = image.to(dtype=torch.float32)
+
+        image = interpolate(
             image,
             size=[new_height, new_width],
-            interpolation=interpolation.value,
+            mode=interpolation.value,
+            align_corners=align_corners,
+            antialias=antialias,
         )
 
+        if need_cast:
+            if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
+                image = image.clamp_(min=0, max=255)
+            image = image.round_().to(dtype=dtype)
+
     return image.reshape(shape[:-3] + (num_channels, new_height, new_width))

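For reference, here is a minimal, self-contained sketch of the behaviour the ported kernel applies to integer images, assuming a recent PyTorch build where torch.nn.functional.interpolate supports antialias for bilinear/bicubic modes. The helper name resize_uint8_like_v2 is invented for this illustration and is not part of torchvision.

# Hedged sketch (not the torchvision kernel itself) of the cast/interpolate/cast-back
# pattern shown in the hunk above. The helper name is hypothetical.
import torch
from torch.nn.functional import interpolate


def resize_uint8_like_v2(image: torch.Tensor, new_height: int, new_width: int) -> torch.Tensor:
    # interpolate() with antialias needs floating-point input, so integer
    # images are upcast to float32 first.
    dtype = image.dtype
    need_cast = dtype not in (torch.float32, torch.float64)
    if need_cast:
        image = image.to(dtype=torch.float32)

    image = interpolate(
        image,
        size=[new_height, new_width],
        mode="bicubic",
        align_corners=False,
        antialias=True,
    )

    if need_cast:
        # Bicubic interpolation can overshoot [0, 255], so clamp before
        # rounding back to the original integer dtype.
        image = image.clamp_(min=0, max=255)
        image = image.round_().to(dtype=dtype)
    return image


# Usage: a batch of one 3-channel uint8 image resized from 64x64 to 32x48.
img = torch.randint(0, 256, (1, 3, 64, 64), dtype=torch.uint8)
out = resize_uint8_like_v2(img, 32, 48)
print(out.shape, out.dtype)  # torch.Size([1, 3, 32, 48]) torch.uint8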
@@ -1312,9 +1330,11 @@ def resized_crop(
 
 def _parse_five_crop_size(size: List[int]) -> List[int]:
     if isinstance(size, numbers.Number):
-        size = [int(size), int(size)]
+        s = int(size)
+        size = [s, s]
     elif isinstance(size, (tuple, list)) and len(size) == 1:
-        size = [size[0], size[0]]
+        s = size[0]
+        size = [s, s]
 
     if len(size) != 2:
         raise ValueError("Please provide only two dimensions (h, w) for size.")
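The _parse_five_crop_size rewrite keeps behaviour identical while avoiding the duplicated conversion expression. Below is a hedged, standalone sketch of the same size normalization; the name _parse_size_sketch is invented for illustration and is not the library function.

# Sketch of the normalization: an int or a one-element sequence becomes [s, s],
# anything that is not two-dimensional after that is rejected.
import numbers
from typing import List


def _parse_size_sketch(size) -> List[int]:
    if isinstance(size, numbers.Number):
        s = int(size)
        size = [s, s]
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        s = size[0]
        size = [s, s]

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")
    return size


print(_parse_size_sketch(224))         # [224, 224]
print(_parse_size_sketch([224]))       # [224, 224]
print(_parse_size_sketch((224, 256)))  # (224, 256) -- already two dims, passed through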
