Skip to content
37 changes: 37 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import torch.nn.functional as F
from common_utils import combinations_grid
from torchvision import datasets
from torchvision.io import decode_image
from torchvision.transforms import v2


Expand Down Expand Up @@ -1175,6 +1176,8 @@ class SBUTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SBU
FEATURE_TYPES = (PIL.Image.Image, str)

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir, config):
num_images = 3

Expand Down Expand Up @@ -1413,6 +1416,8 @@ class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase):
_IMAGES_FOLDER = "images"
_ANNOTATIONS_FILE = "captions.html"

SUPPORT_TV_IMAGE_DECODE = True

def dataset_args(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir)
root = tmpdir / self._IMAGES_FOLDER
Expand Down Expand Up @@ -1482,6 +1487,8 @@ class Flickr30kTestCase(Flickr8kTestCase):

_ANNOTATIONS_FILE = "captions.token"

SUPPORT_TV_IMAGE_DECODE = True

def _image_file_name(self, idx):
return f"{idx}.jpg"

Expand Down Expand Up @@ -1942,6 +1949,8 @@ class LFWPeopleTestCase(datasets_utils.DatasetTestCase):
_IMAGES_DIR = {"original": "lfw", "funneled": "lfw_funneled", "deepfunneled": "lfw-deepfunneled"}
_file_id = {"10fold": "", "train": "DevTrain", "test": "DevTest"}

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir) / "lfw-py"
os.makedirs(tmpdir, exist_ok=True)
Expand Down Expand Up @@ -1978,6 +1987,18 @@ def _create_random_id(self):
part2 = datasets_utils.create_random_string(random.randint(4, 7))
return f"{part1}_{part2}"

def test_tv_decode_image_support(self):
if not self.SUPPORT_TV_IMAGE_DECODE:
pytest.skip(f"{self.DATASET_CLASS.__name__} does not support torchvision.io.decode_image.")

with self.create_dataset(
config=dict(
loader=decode_image,
)
) as (dataset, _):
image = dataset[0][0]
assert isinstance(image, torch.Tensor)


class LFWPairsTestCase(LFWPeopleTestCase):
DATASET_CLASS = datasets.LFWPairs
Expand Down Expand Up @@ -2335,6 +2356,8 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase):

ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir: str, config):
root_folder = pathlib.Path(tmpdir) / "food-101"
image_folder = root_folder / "images"
Expand Down Expand Up @@ -2371,6 +2394,7 @@ class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
ADDITIONAL_CONFIGS = combinations_grid(
split=("train", "val", "trainval", "test"), annotation_level=("variant", "family", "manufacturer")
)
SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir: str, config):
split = config["split"]
Expand Down Expand Up @@ -2420,6 +2444,8 @@ def inject_fake_data(self, tmpdir: str, config):
class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SUN397

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir: str, config):
data_dir = pathlib.Path(tmpdir) / "SUN397"
data_dir.mkdir()
Expand Down Expand Up @@ -2451,6 +2477,8 @@ class DTDTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DTD
FEATURE_TYPES = (PIL.Image.Image, int)

SUPPORT_TV_IMAGE_DECODE = True

ADDITIONAL_CONFIGS = combinations_grid(
split=("train", "test", "val"),
# There is no need to test the whole matrix here, since each fold is treated exactly the same
Expand Down Expand Up @@ -2611,6 +2639,7 @@ class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase):
FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))

ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir, config):
data_folder = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0"
Expand Down Expand Up @@ -2708,6 +2737,8 @@ class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
REQUIRED_PACKAGES = ("scipy",)
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "test"))

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir, config):
import scipy.io as io
from numpy.core.records import fromarrays
Expand Down Expand Up @@ -2782,6 +2813,8 @@ class Flowers102TestCase(datasets_utils.ImageDatasetTestCase):
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
REQUIRED_PACKAGES = ("scipy",)

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir: str, config):
base_folder = pathlib.Path(tmpdir) / "flowers-102"

Expand Down Expand Up @@ -2840,6 +2873,8 @@ class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val", "test"))
SPLIT_TO_FOLDER = {"train": "train", "val": "valid", "test": "test"}

SUPPORT_TV_IMAGE_DECODE = True

def inject_fake_data(self, tmpdir: str, config):
root_folder = pathlib.Path(tmpdir) / "rendered-sst2"
image_folder = root_folder / self.SPLIT_TO_FOLDER[config["split"]]
Expand Down Expand Up @@ -3500,6 +3535,8 @@ class ImagenetteTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Imagenette
ADDITIONAL_CONFIGS = combinations_grid(split=["train", "val"], size=["full", "320px", "160px"])

SUPPORT_TV_IMAGE_DECODE = True

_WNIDS = [
"n01440764",
"n02102040",
Expand Down
13 changes: 9 additions & 4 deletions torchvision/datasets/clevr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Callable, List, Optional, Tuple, Union
from urllib.parse import urlparse

from PIL import Image
from .folder import default_loader

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
Expand All @@ -18,11 +18,14 @@ class CLEVRClassification(VisionDataset):
root (str or ``pathlib.Path``): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
set to True.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in them target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If
dataset is already downloaded, it is not downloaded again.
loader (callable, optional): A function to load an image given its path.
By default, it uses PIL as its image loader, but users could also pass in
``torchvision.io.decode_image`` for decoding image data into tensors directly.
"""

_URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
Expand All @@ -35,9 +38,11 @@ def __init__(
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader,
) -> None:
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
super().__init__(root, transform=transform, target_transform=target_transform)
self.loader = loader
self._base_folder = pathlib.Path(self.root) / "clevr"
self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem

Expand Down Expand Up @@ -65,7 +70,7 @@ def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_file = self._image_files[idx]
label = self._labels[idx]

image = Image.open(image_file).convert("RGB")
image = self.loader(image_file)

if self.transform:
image = self.transform(image)
Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/country211.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class Country211(ImageFolder):
Args:
root (str or ``pathlib.Path``): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and puts it into
``root/country211/``. If dataset is already downloaded, it is not downloaded again.
Expand Down
13 changes: 9 additions & 4 deletions torchvision/datasets/dtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pathlib
from typing import Any, Callable, Optional, Tuple, Union

import PIL.Image
from .folder import default_loader

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
Expand All @@ -21,12 +21,15 @@ class DTD(VisionDataset):
The partition only changes which split each image belongs to. Thus, regardless of the selected
partition, combining all splits will result in all images.

transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
loader (callable, optional): A function to load an image given its path.
By default, it uses PIL as its image loader, but users could also pass in
``torchvision.io.decode_image`` for decoding image data into tensors directly.
"""

_URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
Expand All @@ -40,6 +43,7 @@ def __init__(
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
loader: Callable[[Union[str, pathlib.Path]], Any] = default_loader,
) -> None:
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
if not isinstance(partition, int) and not (1 <= partition <= 10):
Expand Down Expand Up @@ -72,13 +76,14 @@ def __init__(
self.classes = sorted(set(classes))
self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
self._labels = [self.class_to_idx[cls] for cls in classes]
self.loader = loader

def __len__(self) -> int:
return len(self._image_files)

def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")
image = self.loader(image_file)

if self.transform:
image = self.transform(image)
Expand Down
2 changes: 1 addition & 1 deletion torchvision/datasets/eurosat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class EuroSAT(ImageFolder):

Args:
root (str or ``pathlib.Path``): Root directory of dataset where ``root/eurosat`` exists.
transform (callable, optional): A function/transform that takes in a PIL image
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
Expand Down
11 changes: 8 additions & 3 deletions torchvision/datasets/fgvc_aircraft.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

import PIL.Image
from .folder import default_loader

from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
Expand All @@ -29,13 +29,16 @@ class FGVCAircraft(VisionDataset):
``trainval`` and ``test``.
annotation_level (str, optional): The annotation level, supports ``variant``,
``family`` and ``manufacturer``.
transform (callable, optional): A function/transform that takes in a PIL image
transform (callable, optional): A function/transform that takes in a PIL image or torch.Tensor, depends on the given loader,
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
loader (callable, optional): A function to load an image given its path.
By default, it uses PIL as its image loader, but users could also pass in
``torchvision.io.decode_image`` for decoding image data into tensors directly.
"""

_URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
Expand All @@ -48,6 +51,7 @@ def __init__(
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
loader: Callable[[str], Any] = default_loader,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
Expand Down Expand Up @@ -87,13 +91,14 @@ def __init__(
image_name, label_name = line.strip().split(" ", 1)
self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg"))
self._labels.append(self.class_to_idx[label_name])
self.loader = loader

def __len__(self) -> int:
return len(self._image_files)

def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")
image = self.loader(image_file)

if self.transform:
image = self.transform(image)
Expand Down
Loading