11from pathlib import Path
2- from typing import Callable , Optional , Union
2+ from typing import Any , Callable , Optional , Union
33
4- from .folder import ImageFolder
4+ from .folder import default_loader , ImageFolder
55from .utils import download_and_extract_archive , verify_str_arg
66
77
@@ -21,6 +21,7 @@ class Country211(ImageFolder):
2121 target_transform (callable, optional): A function/transform that takes in the target and transforms it.
2222 download (bool, optional): If True, downloads the dataset from the internet and puts it into
2323 ``root/country211/``. If dataset is already downloaded, it is not downloaded again.
24+ loader (callable, optional): A function to load an image given its path.
2425 """
2526
2627 _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
@@ -33,6 +34,7 @@ def __init__(
3334 transform : Optional [Callable ] = None ,
3435 target_transform : Optional [Callable ] = None ,
3536 download : bool = False ,
37+ loader : Callable [[str ], Any ] = default_loader ,
3638 ) -> None :
3739 self ._split = verify_str_arg (split , "split" , ("train" , "valid" , "test" ))
3840
@@ -46,7 +48,12 @@ def __init__(
4648 if not self ._check_exists ():
4749 raise RuntimeError ("Dataset not found. You can use download=True to download it" )
4850
49- super ().__init__ (str (self ._base_folder / self ._split ), transform = transform , target_transform = target_transform )
51+ super ().__init__ (
52+ str (self ._base_folder / self ._split ),
53+ transform = transform ,
54+ target_transform = target_transform ,
55+ loader = loader ,
56+ )
5057 self .root = str (root )
5158
5259 def _check_exists (self ) -> bool :
0 commit comments