Skip to content

Commit 23e0d65

Browse files
uridah authored and soumith committed
SVHN dataset for torchvision (pytorch#98)
1 parent c4f4c73 commit 23e0d65

File tree

3 files changed

+124
-1
lines changed

3 files changed

+124
-1
lines changed

README.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,17 @@ STL10
168168
- ``download`` : ``True`` = downloads the dataset from the internet and
169169
puts it in root directory. If dataset already downloaded, does not do
170170
anything.
171+
172+
SVHN
173+
~~~~~
174+
175+
``dset.SVHN(root, split='train', transform=None, target_transform=None, download=False)``
176+
177+
- ``root`` : root directory of the dataset, where the ``.mat`` file of the chosen split (e.g. ``train_32x32.mat``) is stored
178+
- ``split`` : ``'train'`` = Training set, ``'test'`` = Test set, ``'extra'`` = Extra training set
179+
- ``download`` : ``True`` = downloads the dataset from the internet and
180+
puts it in root directory. If dataset is already downloaded, does not do
181+
anything.
171182

172183
ImageFolder
173184
~~~~~~~~~~~

torchvision/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from .cifar import CIFAR10, CIFAR100
55
from .stl10 import STL10
66
from .mnist import MNIST
7+
from .svhn import SVHN
78

89
__all__ = ('LSUN', 'LSUNClass',
910
'ImageFolder',
1011
'CocoCaptions', 'CocoDetection',
1112
'CIFAR10', 'CIFAR100',
12-
'MNIST', 'STL10')
13+
'MNIST', 'STL10', 'SVHN')

torchvision/datasets/svhn.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from __future__ import print_function
2+
import torch.utils.data as data
3+
from PIL import Image
4+
import os
5+
import os.path
6+
import errno
7+
import numpy as np
8+
import sys
9+
10+
11+
class SVHN(data.Dataset):
    """SVHN dataset backed by one of the ``*_32x32.mat`` files.

    The file for the requested split is looked up directly inside ``root``
    (``os.path.join(root, filename)``), verified against its MD5 checksum and
    loaded fully into memory with ``scipy.io.loadmat``.

    Args:
        root: directory that contains (or will receive) the ``.mat`` file.
        split: key of ``split_list``: ``'train'``, ``'test'`` or ``'extra'``.
        transform: optional callable applied to the PIL image of each sample.
        target_transform: optional callable applied to the label of each
            sample.
        download: if true, fetch the file into ``root`` when a verified copy
            is not already there.

    Raises:
        ValueError: if ``split`` is not a key of ``split_list``.
        RuntimeError: if the file is missing or fails the MD5 check.
    """

    # Placeholders; __init__ overwrites them per instance with the values of
    # the selected split.
    url = ""
    filename = ""
    file_md5 = ""

    # split name -> [download url, local file name, expected md5 hex digest]
    split_list = {
        'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
                  "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
        'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
                 "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
        'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
                  "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # training set or test set or extra set

        if self.split not in self.split_list:
            raise ValueError('Wrong split entered! Please use split="train" or split="extra" or split="test"')

        self.url = self.split_list[split][0]
        self.filename = self.split_list[split][1]
        self.file_md5 = self.split_list[split][2]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # import here rather than at top of file because this is
        # an optional dependency for torchvision
        import scipy.io as sio

        # reading(loading) mat file as array
        loaded_mat = sio.loadmat(os.path.join(root, self.filename))

        self.data = loaded_mat['X']
        # NOTE(review): labels are kept exactly as stored in the file
        # (presumably shape (N, 1), with digit 0 encoded as 10 -- confirm
        # against the dataset documentation before feeding them to a loss
        # that expects classes in [0, C-1]).
        self.labels = loaded_mat['y']
        # Move the last axis to the front so that self.data[index] and
        # len(self.data) operate on samples.
        self.data = np.transpose(self.data, (3, 2, 0, 1))

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``.

        The image is a PIL Image (before ``transform``); the target is the
        corresponding entry of ``self.labels`` (before ``target_transform``).
        """
        img, target = self.data[index], self.labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image; the transpose undoes the per-sample axis
        # reordering applied in __init__.
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.data)

    def _check_integrity(self):
        """Return True if the split's file exists and matches its MD5."""
        import hashlib

        expected_md5 = self.split_list[self.split][2]
        fpath = os.path.join(self.root, self.filename)
        if not os.path.isfile(fpath):
            return False

        # Hash in chunks and close the handle deterministically: the files
        # can be large, so reading them in one go wastes memory, and the
        # original left the file object open until garbage collection.
        digest = hashlib.md5()
        with open(fpath, 'rb') as fh:
            for chunk in iter(lambda: fh.read(1024 * 1024), b''):
                digest.update(chunk)
        return digest.hexdigest() == expected_md5

    def download(self):
        """Download the split's ``.mat`` file into ``root`` if needed.

        Nothing is fetched when a copy that passes the MD5 check is already
        present. A file that exists but fails the check (e.g. an interrupted
        download) is downloaded again instead of being silently reused.
        """
        root = self.root
        fpath = os.path.join(root, self.filename)

        try:
            os.makedirs(root)
        except OSError as e:
            # An already existing directory is fine; anything else is not.
            if e.errno != errno.EEXIST:
                raise

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        # Imported only when a download is really needed, so that an already
        # verified dataset does not require ``six``.
        from six.moves import urllib

        if os.path.isfile(fpath):
            # Reaching this point means the existing file failed the check.
            print('Existing file ' + fpath + ' is corrupted, downloading it again')

        # downloads file
        print('Downloading ' + self.url + ' to ' + fpath)
        urllib.request.urlretrieve(self.url, fpath)
        print('Downloaded!')

0 commit comments

Comments
 (0)