|
4 | 4 |
|
5 | 5 | import PIL.Image |
6 | 6 | import torch |
| 7 | +from torch.nn.functional import interpolate |
7 | 8 | from torchvision.prototype import features |
8 | 9 | from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT |
9 | 10 | from torchvision.transforms.functional import ( |
@@ -115,20 +116,37 @@ def resize_image_tensor( |
115 | 116 | max_size: Optional[int] = None, |
116 | 117 | antialias: bool = False, |
117 | 118 | ) -> torch.Tensor: |
| 119 | + align_corners: Optional[bool] = None |
| 120 | + if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC: |
| 121 | + align_corners = False |
| 122 | + elif antialias: |
| 123 | + raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only") |
| 124 | + |
118 | 125 | shape = image.shape |
119 | 126 | num_channels, old_height, old_width = shape[-3:] |
120 | 127 | new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size) |
121 | 128 |
|
122 | 129 | if image.numel() > 0: |
123 | 130 | image = image.reshape(-1, num_channels, old_height, old_width) |
124 | 131 |
|
125 | | - image = _FT.resize( |
| 132 | + dtype = image.dtype |
| 133 | + need_cast = dtype not in (torch.float32, torch.float64) |
| 134 | + if need_cast: |
| 135 | + image = image.to(dtype=torch.float32) |
| 136 | + |
| 137 | + image = interpolate( |
126 | 138 | image, |
127 | 139 | size=[new_height, new_width], |
128 | | - interpolation=interpolation.value, |
| 140 | + mode=interpolation.value, |
| 141 | + align_corners=align_corners, |
129 | 142 | antialias=antialias, |
130 | 143 | ) |
131 | 144 |
|
| 145 | + if need_cast: |
| 146 | + if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8: |
| 147 | + image = image.clamp_(min=0, max=255) |
| 148 | + image = image.round_().to(dtype=dtype) |
| 149 | + |
132 | 150 | return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) |
133 | 151 |
|
134 | 152 |
|
@@ -1312,9 +1330,11 @@ def resized_crop( |
1312 | 1330 |
|
1313 | 1331 | def _parse_five_crop_size(size: List[int]) -> List[int]: |
1314 | 1332 | if isinstance(size, numbers.Number): |
1315 | | - size = [int(size), int(size)] |
| 1333 | + s = int(size) |
| 1334 | + size = [s, s] |
1316 | 1335 | elif isinstance(size, (tuple, list)) and len(size) == 1: |
1317 | | - size = [size[0], size[0]] |
| 1336 | + s = size[0] |
| 1337 | + size = [s, s] |
1318 | 1338 |
|
1319 | 1339 | if len(size) != 2: |
1320 | 1340 | raise ValueError("Please provide only two dimensions (h, w) for size.") |
|
0 commit comments