|
1 |
| -# pytorch-custom-dataset-examples |
2 |
| -Some custom dataset tutorials for PyTorch |
| 1 | +<p align="center"><img width="40%" src="data/pytorch-logo-dark.png" /></p> |
| 2 | + |
| 3 | +-------------------------------------------------------------------------------- |
| 4 | + |
| 5 | +# PyTorch Custom Dataset Examples |
| 6 | + |
| 7 | +There are some official custom dataset examples on PyTorch repo like [this](https://github.com/pytorch/tutorials/blob/master/beginner_source/data_loading_tutorial.py) but they still seemed a bit obscure to a beginner (like me) so I had to spend some time understanding what exactly I needed to have a fully customized dataset. So here we go. |
| 8 | + |
| 9 | +The first and foremost part is creating a dataset class. |
| 10 | + |
| 11 | +```python |
| 12 | +from torch.utils.data.dataset import Dataset |
| 13 | + |
| 14 | +class CustomDataset(Dataset): |
| 15 | + def __init__(self, a, b, c, d, transform=None): |
| 16 | + # stuff |
| 17 | + |
| 18 | + def __getitem__(self, index): |
| 19 | + # stuff |
| 20 | + return (img, label) |
| 21 | + |
| 22 | + def __len__(self): |
| 23 | + return count # of how many examples(images?) you have |
| 24 | +``` |
| 25 | + |
| 26 | +This is the skeleton that you have to fill to have a custom dataset. A dataset must contain following functions to be used by data loader afterwards. |
| 27 | + |
| 28 | +* **init** function where the initial logic happens like reading a csv, assigning parameters etc. |
| 29 | +* **getitem** function where it returns a tuple of image and the label of the image. This function is called from dataloader like this: |
| 30 | +```python |
| 31 | +img, label = CustomDataset.__getitem__(99) |
| 32 | +``` |
| 33 | +So, the index parameter is the **n**th image(as tensor) you are going to return. |
| 34 | + |
| 35 | +* **len** function where it returns count of samples you have. |
| 36 | + |
| 37 | +The first example consists of having a csv file like following(without the headers, even though it really doesn't matter), that contains file name, label(class) and an extra operation indicator. This csv file pretty much shows which image belongs to which class. |
| 38 | + |
| 39 | + File Name | Label | Extra Operation | |
| 40 | +| ------------- |:-------------:| :-----:| |
| 41 | +| tr_0.png | 5 | TRUE | |
| 42 | +| tr_1.png | 0 | FALSE | |
| 43 | +| tr_1.png | 4 | FALSE | |
| 44 | + |
| 45 | +If we want to build a custom dataset that reads this csv file and images from a location we can do something like following. |
| 46 | + |
| 47 | +```python |
| 48 | +class CustomDatasetFromImages(Dataset): |
| 49 | + def __init__(self, csv_path, img_path, transform=None): |
| 50 | + """ |
| 51 | + Args: |
| 52 | + csv_path (string): path to csv file |
| 53 | + img_path (string): path to the folder where images are |
| 54 | + transform: pytorch transforms for transforms and tensor conversion |
| 55 | + """ |
| 56 | + # Read the csv file |
| 57 | + self.data_info = pd.read_csv(csv_path, header=None) |
| 58 | + self.img_path = img_path # Assign image path |
| 59 | + self.transform = transform # Assign transform |
| 60 | + self.labels = np.asarray(self.data_info.iloc[:, 1]) # Second column is the labels |
| 61 | + # Third column is for operation indicator |
| 62 | + self.operation = np.asarray(self.data_info.iloc[:, 2]) |
| 63 | + |
| 64 | + def __getitem__(self, index): |
| 65 | + # Get label(class) of the image based on the cropped pandas column |
| 66 | + single_image_label = self.labels[index] |
| 67 | + # Get image name from the pandas df |
| 68 | + single_image_name = self.data_info.iloc[index][0] |
| 69 | + # Open image |
| 70 | + img_as_img = Image.open(self.img_path + '/' + single_image_name) |
| 71 | + # If there is an operation |
| 72 | + if self.operation[index] == 'TRUE': |
| 73 | + # Do some operation on image |
| 74 | + pass |
| 75 | + # Transform image to tensor |
| 76 | + if self.transform is not None: |
| 77 | + img_as_tensor = self.transform(img_as_img) |
| 78 | + # Return image and the label |
| 79 | + return (img_as_tensor, single_image_label) |
| 80 | + |
| 81 | + def __len__(self): |
| 82 | + return len(self.data_info.index) |
| 83 | +``` |
| 84 | +In most of the examples, if not all, when a dataset is called, it is given a transform operation like this: |
| 85 | +```python |
| 86 | +transformations = transforms.Compose([transforms.ToTensor()]) |
| 87 | +custom_mnist_from_images = CustomDatasetFromImages('path_to_csv', 'path_to_images', transformations) |
| 88 | +``` |
| 89 | +transformations can contain more operations like normalize, random crop etc. The source code is [here](https://github.com/pytorch/vision/blob/master/torchvision/transforms.py). |
0 commit comments