
Commit 8fdaeb0

vfdev-5 and NicolasHug authored
Image and Mask can accept PIL images (#7231)
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
1 parent e21e436 commit 8fdaeb0
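The headline change: datapoints.Image and datapoints.Mask now accept PIL images directly instead of requiring a tensor. A minimal usage sketch of the new behavior (the printed shape and dtype assume pil_to_tensor semantics, i.e. a uint8 CHW tensor with no value rescaling):

from PIL import Image
from torchvision.prototype import datapoints

pil_img = Image.new("RGB", (32, 32), color=123)
img = datapoints.Image(pil_img)  # previously this required a tensor / array-like
print(type(img).__name__, img.shape, img.dtype)  # Image torch.Size([3, 32, 32]) torch.uint8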

File tree

5 files changed, +42 −2 lines

test/prototype_transforms_kernel_infos.py
test/test_prototype_datapoints.py
torchvision/prototype/datapoints/_dataset_wrapper.py
torchvision/prototype/datapoints/_image.py
torchvision/prototype/datapoints/_mask.py

test/prototype_transforms_kernel_infos.py

Lines changed: 1 addition & 1 deletion
@@ -835,7 +835,7 @@ def sample_inputs_rotate_video():
         F.rotate_bounding_box,
         sample_inputs_fn=sample_inputs_rotate_bounding_box,
         closeness_kwargs={
-            **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-6, rtol=1e-6),
+            **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
             **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
         },
     ),
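For context, this loosens the CPU tolerance of the scripted-vs-eager double-precision comparison for rotate_bounding_box from 1e-6 to 1e-5, matching the existing CUDA tolerance. Assuming the helper ultimately forwards atol/rtol to torch.testing.assert_close, the check being relaxed looks roughly like this:

import torch

expected = torch.tensor([100.0], dtype=torch.float64)
actual = expected + 5e-4  # a ~5e-4 discrepancy on a coordinate of magnitude ~100

# passes: |actual - expected| <= atol + rtol * |expected| = 1e-5 + 1e-5 * 100
torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)
# the old tolerances (atol=rtol=1e-6) would have rejected the same discrepancy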

test/test_prototype_datapoints.py

Lines changed: 29 additions & 0 deletions
@@ -1,5 +1,7 @@
 import pytest
 import torch
+
+from PIL import Image
 from torchvision.prototype import datapoints


@@ -130,3 +132,30 @@ def test_wrap_like():
     assert type(label_new) is datapoints.Label
     assert label_new.data_ptr() == output.data_ptr()
     assert label_new.categories is label.categories
+
+
+@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
+def test_image_instance(data):
+    image = datapoints.Image(data)
+    assert isinstance(image, torch.Tensor)
+    assert image.ndim == 3 and image.shape[0] == 3
+
+
+@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
+def test_mask_instance(data):
+    mask = datapoints.Mask(data)
+    assert isinstance(mask, torch.Tensor)
+    assert mask.ndim == 3 and mask.shape[0] == 1
+
+
+@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]])
+@pytest.mark.parametrize(
+    "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
+)
+def test_bbox_instance(data, format):
+    bboxes = datapoints.BoundingBox(data, format=format, spatial_size=(32, 32))
+    assert isinstance(bboxes, torch.Tensor)
+    assert bboxes.ndim == 2 and bboxes.shape[1] == 4
+    if isinstance(format, str):
+        format = datapoints.BoundingBoxFormat.from_str(format.upper())
+    assert bboxes.format == format
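The new test_bbox_instance also exercises the existing conversion paths: plain Python lists are turned into tensors and string formats are resolved to BoundingBoxFormat members. A small sketch mirroring the test:

from torchvision.prototype import datapoints

bboxes = datapoints.BoundingBox(
    [[0, 0, 5, 5], [2, 2, 7, 7]],  # plain list, converted to a tensor internally
    format="XYXY",                 # string form, resolved to BoundingBoxFormat.XYXY
    spatial_size=(32, 32),
)
print(bboxes.shape, bboxes.format)  # torch.Size([2, 4]) BoundingBoxFormat.XYXY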

torchvision/prototype/datapoints/_dataset_wrapper.py

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ def identity(item):


 def pil_image_to_mask(pil_image):
-    return datapoints.Mask(F.to_image_tensor(pil_image).squeeze(0))
+    return datapoints.Mask(pil_image)


 def list_of_dicts_to_dict_of_lists(list_of_dicts):
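With Mask accepting PIL input, the dataset wrapper no longer needs the to_image_tensor round trip. A minimal sketch of what the simplified helper now produces for an "L"-mode segmentation mask (shape and dtype assume pil_to_tensor semantics, consistent with test_mask_instance above):

from PIL import Image
from torchvision.prototype import datapoints

pil_mask = Image.new("L", (32, 32), color=2)  # e.g. a dataset segmentation mask
mask = datapoints.Mask(pil_mask)
print(mask.shape, mask.dtype)  # torch.Size([1, 32, 32]) torch.uint8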

torchvision/prototype/datapoints/_image.py

Lines changed: 5 additions & 0 deletions
@@ -23,6 +23,11 @@ def __new__(
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: Optional[bool] = None,
     ) -> Image:
+        if isinstance(data, PIL.Image.Image):
+            from torchvision.prototype.transforms import functional as F
+
+            data = F.pil_to_tensor(data)
+
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         if tensor.ndim < 2:
             raise ValueError
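The PIL branch runs before the existing _to_tensor call, so the constructor's other keyword arguments still apply to the converted data. A small sketch, assuming _to_tensor simply casts (as a plain torch.as_tensor would) rather than rescaling:

import torch
from PIL import Image
from torchvision.prototype import datapoints

img = datapoints.Image(Image.new("RGB", (8, 8), color=(10, 20, 30)), dtype=torch.float32)
print(img.dtype, img.shape)  # torch.float32 torch.Size([3, 8, 8]) -- values still in 0..255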

torchvision/prototype/datapoints/_mask.py

Lines changed: 6 additions & 0 deletions
@@ -2,6 +2,7 @@

 from typing import Any, List, Optional, Tuple, Union

+import PIL.Image
 import torch
 from torchvision.transforms import InterpolationMode

@@ -21,6 +22,11 @@ def __new__(
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: Optional[bool] = None,
     ) -> Mask:
+        if isinstance(data, PIL.Image.Image):
+            from torchvision.prototype.transforms import functional as F
+
+            data = F.pil_to_tensor(data)
+
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         return cls._wrap(tensor)

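One note on the use of F.pil_to_tensor here: unlike to_tensor, it keeps the raw integer pixel values instead of rescaling to [0, 1] floats, which matters for a Mask so that class ids survive the conversion. A small comparison sketch:

from PIL import Image
from torchvision.prototype import datapoints
from torchvision.transforms import functional as TF

pil_mask = Image.new("L", (2, 2), color=7)
print(datapoints.Mask(pil_mask).unique())  # class id 7 preserved (uint8)
print(TF.to_tensor(pil_mask).unique())     # ~0.0275, rescaled to a float in [0, 1]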
