Skip to content

Commit 8e9d9ea

Browse files
authored
Update README.md
1 parent b662025 commit 8e9d9ea

File tree

1 file changed

+89
-2
lines changed

1 file changed

+89
-2
lines changed

README.md

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,89 @@
1-
# pytorch-custom-dataset-examples
2-
Some custom dataset tutorials for PyTorch
1+
<p align="center"><img width="40%" src="data/pytorch-logo-dark.png" /></p>
2+
3+
--------------------------------------------------------------------------------
4+
5+
# PyTorch Custom Dataset Examples
6+
7+
There are some official custom dataset examples on PyTorch repo like [this](https://github.com/pytorch/tutorials/blob/master/beginner_source/data_loading_tutorial.py) but they still seemed a bit obscure to a beginner (like me) so I had to spend some time understanding what exactly I needed to have a fully customized dataset. So here we go.
8+
9+
The first and foremost part is creating a dataset class.
10+
11+
```python
12+
from torch.utils.data.dataset import Dataset
13+
14+
class CustomDataset(Dataset):
15+
def __init__(self, a, b, c, d, transform=None):
16+
# stuff
17+
18+
def __getitem__(self, index):
19+
# stuff
20+
return (img, label)
21+
22+
def __len__(self):
23+
return count # of how many examples(images?) you have
24+
```
25+
26+
This is the skeleton that you have to fill to have a custom dataset. A dataset must contain following functions to be used by data loader afterwards.
27+
28+
* **init** function where the initial logic happens like reading a csv, assigning parameters etc.
29+
* **getitem** function where it returns a tuple of image and the label of the image. This function is called from dataloader like this:
30+
```python
31+
img, label = CustomDataset.__getitem__(99)
32+
```
33+
So, the index parameter is the **n**th image(as tensor) you are going to return.
34+
35+
* **len** function where it returns count of samples you have.
36+
37+
The first example consists of having a csv file like following(without the headers, even though it really doesn't matter), that contains file name, label(class) and an extra operation indicator. This csv file pretty much shows which image belongs to which class.
38+
39+
File Name | Label | Extra Operation |
40+
| ------------- |:-------------:| :-----:|
41+
| tr_0.png | 5 | TRUE |
42+
| tr_1.png | 0 | FALSE |
43+
| tr_1.png | 4 | FALSE |
44+
45+
If we want to build a custom dataset that reads this csv file and images from a location we can do something like following.
46+
47+
```python
48+
class CustomDatasetFromImages(Dataset):
49+
def __init__(self, csv_path, img_path, transform=None):
50+
"""
51+
Args:
52+
csv_path (string): path to csv file
53+
img_path (string): path to the folder where images are
54+
transform: pytorch transforms for transforms and tensor conversion
55+
"""
56+
# Read the csv file
57+
self.data_info = pd.read_csv(csv_path, header=None)
58+
self.img_path = img_path # Assign image path
59+
self.transform = transform # Assign transform
60+
self.labels = np.asarray(self.data_info.iloc[:, 1]) # Second column is the labels
61+
# Third column is for operation indicator
62+
self.operation = np.asarray(self.data_info.iloc[:, 2])
63+
64+
def __getitem__(self, index):
65+
# Get label(class) of the image based on the cropped pandas column
66+
single_image_label = self.labels[index]
67+
# Get image name from the pandas df
68+
single_image_name = self.data_info.iloc[index][0]
69+
# Open image
70+
img_as_img = Image.open(self.img_path + '/' + single_image_name)
71+
# If there is an operation
72+
if self.operation[index] == 'TRUE':
73+
# Do some operation on image
74+
pass
75+
# Transform image to tensor
76+
if self.transform is not None:
77+
img_as_tensor = self.transform(img_as_img)
78+
# Return image and the label
79+
return (img_as_tensor, single_image_label)
80+
81+
def __len__(self):
82+
return len(self.data_info.index)
83+
```
84+
In most of the examples, if not all, when a dataset is called, it is given a transform operation like this:
85+
```python
86+
transformations = transforms.Compose([transforms.ToTensor()])
87+
custom_mnist_from_images = CustomDatasetFromImages('path_to_csv', 'path_to_images', transformations)
88+
```
89+
transformations can contain more operations like normalize, random crop etc. The source code is [here](https://github.com/pytorch/vision/blob/master/torchvision/transforms.py).

0 commit comments

Comments
 (0)