Commit 255526a

Update README.md
1 parent 6395480 commit 255526a

1 file changed

+18
-37
lines changed

README.md

Lines changed: 18 additions & 37 deletions
@@ -32,17 +32,20 @@ class MyCustomDataset(Dataset):
3232
return count # of how many examples(images?) you have
3333
```
3434

35-
This is the skeleton that you have to fill to have a custom dataset. A dataset must contain following functions to be used by data loader afterwards.
35+
This is the skeleton that you have to fill in to build a custom dataset. A dataset must contain the following functions so that the data loader can use it later on.
3636

37-
* **init** function where the initial logic happens like reading a csv, assigning parameters etc.
38-
* **getitem** function where it returns a tuple of image and the label of the image. This function is called from dataloader like this:
37+
* `__init__()` is where the initial logic happens, such as reading a CSV or assigning transforms.
38+
* `__getitem__()` returns the data and its label. This function is called from the data loader like this:
3939
```python
40-
img, label = CustomDataset.__getitem__(99)
40+
img, label = CustomDataset.__getitem__(99)  # For the item at index 99
4141
```
42-
So, the index parameter is the **n**th image(as tensor, numpy whatever you want) you are going to return.
42+
So, the index parameter selects the **n**th data point/image (as a tensor) that you are going to return.
4343

44-
* **len** function where it returns count of samples you have.
44+
* `__len__()` returns the number of samples you have.
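
The three methods listed above can be sketched in plain Python. This is only a rough illustration of the protocol, not the README's own code: `ToyDataset` and `iterate_batches` are hypothetical names, the samples are lists of floats rather than tensors, and the hand-written loop only approximates what a data loader does with `__len__()` and `__getitem__()`.

```python
class ToyDataset:
    """Hypothetical stand-in for a torch.utils.data.Dataset subclass."""

    def __init__(self, samples, labels, transform=None):
        # Initial logic: store the data and an optional transform
        if len(samples) != len(labels):
            raise ValueError("samples and labels must have the same length")
        self.samples = samples
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        # Return the data point at `index` together with its label
        sample = self.samples[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.labels[index]

    def __len__(self):
        # Number of samples in the dataset
        return len(self.samples)


def iterate_batches(dataset, batch_size):
    """Rough approximation of a loader: index the dataset, group into batches."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        items = [dataset[i] for i in range(start, stop)]
        data = [item[0] for item in items]
        labels = [item[1] for item in items]
        yield data, labels


if __name__ == "__main__":
    dataset = ToyDataset([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]], [0, 1, 0])
    for data, labels in iterate_batches(dataset, batch_size=2):
        print(data, labels)
```

Note that `dataset[i]` is just another way of calling `dataset.__getitem__(i)`, and `len(dataset)` calls `__len__()`.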
4545

46+
An important thing to note is that `__getitem__()` must return a specific type for a single data point (like a tensor or a numpy array); otherwise, the data loader will raise an error like:
47+
48+
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'PIL.PngImagePlugin.PngImageFile'>
4649

4750
## Using Torchvision Transforms
4851
In most of the examples you see `transforms = None` in the `__init__()`, this is used to apply torchvision transforms to your data/image. You can find the extensive list of the transforms [here](http://pytorch.org/docs/0.2.0/torchvision/transforms.html) and [here](https://github.com/pytorch/vision/blob/master/torchvision/transforms/transforms.py). The most common usage of transforms is like this:
@@ -127,7 +130,7 @@ if __name__ == '__main__':
127130

128131
## Incorporating Pandas
129132

130-
Let's say we want to read some data from a csv with pandas. The first example is of having a csv file like following (without the headers, even though it really doesn't matter), that contains file name, label(class) and an extra operation indicator.
133+
Let's say we want to read some data from a CSV with pandas. The first example uses a CSV file like the following (without headers, even though it really doesn't matter) that contains the file name, the label (class), and an extra operation indicator. Depending on this extra operation flag, we do some operation on the image.
131134

132135
File Name | Label | Extra Operation |
133136
| ------------- |:-------------:| :-----:|
@@ -183,40 +186,16 @@ class CustomDatasetFromImages(Dataset):
183186

184187
def __len__(self):
185188
return self.data_len
186-
```
187-
188-
## Incorporating Pandas
189189

190-
In most of the examples, if not all, when a dataset is called, it is given a transform operation like this:
191-
```python
192-
transformations = transforms.Compose([transforms.ToTensor()])
193-
custom_mnist_from_images = CustomDatasetFromImages('path_to_csv', 'path_to_images', transformations)
190+
if __name__ == "__main__":
191+
# Call dataset
192+
custom_mnist_from_images = \
193+
CustomDatasetFromImages('../data/mnist_labels.csv')
194194
```
195-
transforms can contain more operations like normalize, random crop etc. The source code is [here](https://github.com/pytorch/vision/blob/master/torchvision/transforms.py). But at the end, it will probably contain transforms.ToTensor(). This operation turns the PIL images to tensors so that you can feed it to models. Thats why just before returning the tuple in **__getitem__** we do:
196-
```python
197-
# Transform image to tensor
198-
if self.transform is not None:
199-
img_as_tensor = self.transform(img_as_img)
200-
```
201-
Also, ToTensor can convert both grayscale and RGB images so you don't have to worry about how many channels you have for images.
202195

203-
**__len__** function is supposed to return the amount of images(or samples) you have, since we read the csv to pandas df at the beginning we can get the amount of samples as `len(self.data_info.index)`.
196+
## Incorporating Pandas - Putting More Stuff in `__getitem__()`
204197

205-
You can also make use of other columns in csv to do some operations, you just have to read the column and operate on image if there is an operation.
206-
```python
207-
...
208-
# Third column is for operation indicator
209-
self.operation = np.asarray(self.data_info.iloc[:, 2])
210-
...
211-
212-
...
213-
if self.operation[index] == 'TRUE':
214-
# Do some operation
215-
...
216-
217-
```
218-
219-
Yet another example might be reading an image from CSV where the value of each pixel is listed in a column. This just changes the parameters we take as an input and the logic in **__getitem__**. In the end, you just return images as tensors and their labels.
198+
Yet another example might be reading an image from a CSV where the value of each pixel is listed in a column (MNIST is sometimes given this way). This just changes the logic in `__getitem__()`. In the end, you just return images as tensors along with their labels.
220199

221200
```python
222201
class CustomDatasetFromCSV(Dataset):
@@ -256,4 +235,6 @@ class CustomDatasetFromCSV(Dataset):
256235
return len(self.data.index)
257236
```
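
The row-to-image step behind a pixel-per-column CSV can be sketched without pandas or PyTorch. This is a minimal illustration under stated assumptions, not the code from this README: `row_to_image` is a hypothetical helper, and it assumes the first column holds the label and the remaining columns hold the pixel values in row-major order.

```python
import csv
import io


def row_to_image(row, height, width):
    """Split one CSV row into (image, label).

    Assumes the first column is the label and the remaining
    height * width columns are pixel values in row-major order.
    """
    label = int(row[0])
    pixels = [int(value) for value in row[1:]]
    if len(pixels) != height * width:
        raise ValueError(f"expected {height * width} pixels, got {len(pixels)}")
    # Reshape the flat pixel list into `height` rows of `width` values
    image = [pixels[r * width:(r + 1) * width] for r in range(height)]
    return image, label


if __name__ == "__main__":
    # Tiny 2x2 "images" stand in for 28x28 MNIST rows
    csv_text = "7,0,255,128,64\n3,1,2,3,4\n"
    for row in csv.reader(io.StringIO(csv_text)):
        print(row_to_image(row, height=2, width=2))
```

In a real dataset the nested list would typically be converted to a tensor before being returned from `__getitem__()`.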
258237

238+
239+
259240
I will continue updating this repo if I do some fancy stuff in the future that is different from these examples.
