
Super Kai (Kazuya Ito)


Caltech256 in PyTorch



Caltech256() loads the Caltech 256 dataset as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is transform(Optional-Default:None-Type:callable).
  • The 3rd argument is target_transform(Optional-Default:None-Type:callable).
  • The 4th argument is download(Optional-Default:False-Type:bool): *Memos:
    • If it's True, the dataset is downloaded from the internet and extracted(unzipped) to root.
    • If it's True and the dataset is already downloaded, it's extracted.
    • If it's True and the dataset is already downloaded and extracted, nothing happens.
    • It should be False if the dataset is already downloaded and extracted because it's faster.
    • You can manually download and extract the dataset(256_ObjectCategories.tar) from here to data/caltech256/.
  • Each category(class) label covers a contiguous range of image indices: ak47(0) is 0~97, american-flag(1) is 98~194, backpack(2) is 195~345, baseball-bat(3) is 346~472, baseball-glove(4) is 473~620, basketball-hoop(5) is 621~710, bat(6) is 711~816, bathtub(7) is 817~1048, bear(8) is 1049~1150, beer-mug(9) is 1151~1244, etc.
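The download memos above suggest passing download=False once the dataset is already on disk. A minimal sketch of that check follows. The helper name needs_download is hypothetical, and the folder name 256_ObjectCategories is an assumption based on the archive name, so adjust it if your extracted layout differs.

```python
from pathlib import Path


def needs_download(root="data"):
    """Return True when the extracted image folder is missing under root.

    Assumption: the archive unpacks to <root>/caltech256/256_ObjectCategories,
    matching the archive name and the root location shown in this post.
    """
    extracted = Path(root) / "caltech256" / "256_ObjectCategories"
    return not extracted.is_dir()


# Hypothetical usage (requires torchvision, so it is left commented out):
# from torchvision.datasets import Caltech256
# my_data = Caltech256(root="data", download=needs_download("data"))
```

This only checks that the folder exists, not that every image is present, so it is a convenience rather than an integrity check.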
from torchvision.datasets import Caltech256

my_data = Caltech256(
    root="data"
)

my_data = Caltech256(
    root="data",
    transform=None,
    target_transform=None,
    download=False
)

len(my_data)
# 30607

my_data
# Dataset Caltech256
#     Number of datapoints: 30607
#     Root location: data\caltech256

my_data.root
# 'data/caltech256'

print(my_data.transform)
# None

print(my_data.target_transform)
# None

my_data.download
# <bound method Caltech256.download of Dataset Caltech256
#     Number of datapoints: 30607
#     Root location: data\caltech256>

len(my_data.categories), my_data.categories
# (257,
#  ['001.ak47', '002.american-flag', '003.backpack', '004.baseball-bat',
#   '005.baseball-glove', '006.basketball-hoop', '007.bat', '008.bathtub',
#   '009.bear', '010.beer-mug', '011.billiards', '012.binoculars',
#   ...
#   '254.greyhound', '255.tennis-shoes', '256.toad', '257.clutter'])

my_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=499x278>, 0)

my_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=268x218>, 0)

my_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x186>, 0)

my_data[98]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x328>, 1)

my_data[195]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=375x500>, 2)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    plt.figure(figsize=(12, 6))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    ims = (0, 1, 2, 98, 195, 346, 473, 621, 711, 817)
    for i, j in enumerate(iterable=ims, start=1):
        plt.subplot(2, 5, i)
        im, lab = data[j]
        plt.imshow(X=im)
        plt.title(label=lab)
    plt.tight_layout(h_pad=3.0)
    plt.show()

show_images(data=my_data, main_title="my_data")
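The plot titles above show only numeric labels. As a small sketch, the index ranges and category names listed in this post can be combined to look up a label and a readable name without loading any images. The helper names label_for_index and readable_name are hypothetical, and the tables below cover only the first ten categories, so this is an illustration rather than a full mapping.

```python
from bisect import bisect_right

# First image index of each category, copied from the ranges listed above
# (only the first ten categories are covered here).
FIRST_INDEX = [0, 98, 195, 346, 473, 621, 711, 817, 1049, 1151]
LAST_COVERED_INDEX = 1244  # last image index of beer-mug(9)

# Folder-style names as printed by my_data.categories.
CATEGORIES = ['001.ak47', '002.american-flag', '003.backpack',
              '004.baseball-bat', '005.baseball-glove',
              '006.basketball-hoop', '007.bat', '008.bathtub',
              '009.bear', '010.beer-mug']


def label_for_index(i):
    """Return the label of image index i (first ten categories only)."""
    if not 0 <= i <= LAST_COVERED_INDEX:
        raise IndexError(f"index {i} is outside the covered range")
    # Count how many categories start at or before i, then shift to 0-based.
    return bisect_right(FIRST_INDEX, i) - 1


def readable_name(category):
    """Strip the numeric prefix: '001.ak47' -> 'ak47'."""
    return category.split(".", 1)[1]


for i in (0, 98, 195, 1244):
    lab = label_for_index(i)
    print(i, lab, readable_name(CATEGORIES[lab]))
# 0 0 ak47
# 98 1 american-flag
# 195 2 backpack
# 1244 9 beer-mug
```

With the real dataset, readable_name(my_data.categories[lab]) could replace the numeric title in show_images().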

