|
1 | 1 | import pytest |
2 | 2 | import torch |
| 3 | + |
| 4 | +from PIL import Image |
3 | 5 | from torchvision.prototype import datapoints |
4 | 6 |
|
5 | 7 |
|
@@ -130,3 +132,30 @@ def test_wrap_like(): |
130 | 132 | assert type(label_new) is datapoints.Label |
131 | 133 | assert label_new.data_ptr() == output.data_ptr() |
132 | 134 | assert label_new.categories is label.categories |
| 135 | + |
| 136 | + |
| 137 | +@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)]) |
| 138 | +def test_image_instance(data): |
| 139 | + image = datapoints.Image(data) |
| 140 | + assert isinstance(image, torch.Tensor) |
| 141 | + assert image.ndim == 3 and image.shape[0] == 3 |
| 142 | + |
| 143 | + |
| 144 | +@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)]) |
| 145 | +def test_mask_instance(data): |
| 146 | + mask = datapoints.Mask(data) |
| 147 | + assert isinstance(mask, torch.Tensor) |
| 148 | + assert mask.ndim == 3 and mask.shape[0] == 1 |
| 149 | + |
| 150 | + |
| 151 | +@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]]) |
| 152 | +@pytest.mark.parametrize( |
| 153 | + "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH] |
| 154 | +) |
| 155 | +def test_bbox_instance(data, format): |
| 156 | + bboxes = datapoints.BoundingBox(data, format=format, spatial_size=(32, 32)) |
| 157 | + assert isinstance(bboxes, torch.Tensor) |
| 158 | + assert bboxes.ndim == 2 and bboxes.shape[1] == 4 |
| 159 | + if isinstance(format, str): |
| 160 | + format = datapoints.BoundingBoxFormat.from_str(format.upper()) |
| 161 | + assert bboxes.format == format |
0 commit comments