Skip to content

Commit 3d7a613

Browse files
Merge pull request #2242 from wanghaoshuang/flowers_reader
Add flowers dataset for image classification model
2 parents 2c7d8e7 + 990b7d7 commit 3d7a613

File tree

4 files changed

+399
-7
lines changed

4 files changed

+399
-7
lines changed
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
This module will download dataset from
16+
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
17+
and parse train/test set intopaddle reader creators.
18+
19+
This set contains images of flowers belonging to 102 different categories.
20+
The images were acquired by searching the web and taking pictures. There are a
21+
minimum of 40 images for each category.
22+
23+
The database was used in:
24+
25+
Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
26+
number of classes.Proceedings of the Indian Conference on Computer Vision,
27+
Graphics and Image Processing (2008)
28+
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
29+
30+
"""
31+
import cPickle
32+
import itertools
33+
from common import download
34+
import tarfile
35+
import scipy.io as scio
36+
from paddle.v2.image import *
37+
import os
38+
import numpy as np
39+
import paddle.v2 as paddle
40+
from multiprocessing import cpu_count
41+
__all__ = ['train', 'test', 'valid']
42+
43+
DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
44+
LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat'
45+
SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
46+
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
47+
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
48+
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
49+
50+
51+
def default_mapper(sample):
52+
'''
53+
map image bytes data to type needed by model input layer
54+
'''
55+
img, label = sample
56+
img = paddle.image.load_image_bytes(img)
57+
img = paddle.image.simple_transform(img, 256, 224, True)
58+
return img.flatten().astype('float32'), label
59+
60+
61+
def reader_creator(data_file,
62+
label_file,
63+
setid_file,
64+
dataset_name,
65+
mapper=default_mapper,
66+
buffered_size=1024):
67+
'''
68+
1. read images from tar file and
69+
merge images into batch files in 102flowers.tgz_batch/
70+
2. get a reader to read sample from batch file
71+
72+
:param data_file: downloaded data file
73+
:type data_file: string
74+
:param label_file: downloaded label file
75+
:type label_file: string
76+
:param setid_file: downloaded setid file containing information
77+
about how to split dataset
78+
:type setid_file: string
79+
:param dataset_name: data set name (tstid|trnid|valid)
80+
:type dataset_name: string
81+
:param mapper: a function to map image bytes data to type
82+
needed by model input layer
83+
:type mapper: callable
84+
:param buffered_size: the size of buffer used to process images
85+
:type buffered_size: int
86+
:return: data reader
87+
:rtype: callable
88+
'''
89+
labels = scio.loadmat(label_file)['labels'][0]
90+
indexes = scio.loadmat(setid_file)[dataset_name][0]
91+
img2label = {}
92+
for i in indexes:
93+
img = "jpg/image_%05d.jpg" % i
94+
img2label[img] = labels[i - 1]
95+
file_list = batch_images_from_tar(data_file, dataset_name, img2label)
96+
97+
def reader():
98+
for file in open(file_list):
99+
file = file.strip()
100+
batch = None
101+
with open(file, 'r') as f:
102+
batch = cPickle.load(f)
103+
data = batch['data']
104+
labels = batch['label']
105+
for sample, label in itertools.izip(data, batch['label']):
106+
yield sample, int(label)
107+
108+
return paddle.reader.xmap_readers(mapper, reader,
109+
cpu_count(), buffered_size)
110+
111+
112+
def train(mapper=default_mapper, buffered_size=1024):
113+
'''
114+
Create flowers training set reader.
115+
It returns a reader, each sample in the reader is
116+
image pixels in [0, 1] and label in [1, 102]
117+
translated from original color image by steps:
118+
1. resize to 256*256
119+
2. random crop to 224*224
120+
3. flatten
121+
:param mapper: a function to map sample.
122+
:type mapper: callable
123+
:param buffered_size: the size of buffer used to process images
124+
:type buffered_size: int
125+
:return: train data reader
126+
:rtype: callable
127+
'''
128+
return reader_creator(
129+
download(DATA_URL, 'flowers', DATA_MD5),
130+
download(LABEL_URL, 'flowers', LABEL_MD5),
131+
download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper,
132+
buffered_size)
133+
134+
135+
def test(mapper=default_mapper, buffered_size=1024):
136+
'''
137+
Create flowers test set reader.
138+
It returns a reader, each sample in the reader is
139+
image pixels in [0, 1] and label in [1, 102]
140+
translated from original color image by steps:
141+
1. resize to 256*256
142+
2. random crop to 224*224
143+
3. flatten
144+
:param mapper: a function to map sample.
145+
:type mapper: callable
146+
:param buffered_size: the size of buffer used to process images
147+
:type buffered_size: int
148+
:return: test data reader
149+
:rtype: callable
150+
'''
151+
return reader_creator(
152+
download(DATA_URL, 'flowers', DATA_MD5),
153+
download(LABEL_URL, 'flowers', LABEL_MD5),
154+
download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper,
155+
buffered_size)
156+
157+
158+
def valid(mapper=default_mapper, buffered_size=1024):
159+
'''
160+
Create flowers validation set reader.
161+
It returns a reader, each sample in the reader is
162+
image pixels in [0, 1] and label in [1, 102]
163+
translated from original color image by steps:
164+
1. resize to 256*256
165+
2. random crop to 224*224
166+
3. flatten
167+
:param mapper: a function to map sample.
168+
:type mapper: callable
169+
:param buffered_size: the size of buffer used to process images
170+
:type buffered_size: int
171+
:return: test data reader
172+
:rtype: callable
173+
'''
174+
return reader_creator(
175+
download(DATA_URL, 'flowers', DATA_MD5),
176+
download(LABEL_URL, 'flowers', LABEL_MD5),
177+
download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper,
178+
buffered_size)
179+
180+
181+
def fetch():
182+
download(DATA_URL, 'flowers', DATA_MD5)
183+
download(LABEL_URL, 'flowers', LABEL_MD5)
184+
download(SETID_URL, 'flowers', SETID_MD5)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle.v2.dataset.flowers
16+
import unittest
17+
18+
19+
class TestFlowers(unittest.TestCase):
20+
def check_reader(self, reader):
21+
sum = 0
22+
label = 0
23+
size = 224 * 224 * 3
24+
for l in reader():
25+
self.assertEqual(l[0].size, size)
26+
if l[1] > label:
27+
label = l[1]
28+
sum += 1
29+
return sum, label
30+
31+
def test_train(self):
32+
instances, max_label_value = self.check_reader(
33+
paddle.v2.dataset.flowers.train())
34+
self.assertEqual(instances, 1020)
35+
self.assertEqual(max_label_value, 102)
36+
37+
def test_test(self):
38+
instances, max_label_value = self.check_reader(
39+
paddle.v2.dataset.flowers.test())
40+
self.assertEqual(instances, 6149)
41+
self.assertEqual(max_label_value, 102)
42+
43+
def test_valid(self):
44+
instances, max_label_value = self.check_reader(
45+
paddle.v2.dataset.flowers.valid())
46+
self.assertEqual(instances, 1020)
47+
self.assertEqual(max_label_value, 102)
48+
49+
50+
if __name__ == '__main__':
51+
unittest.main()

python/paddle/v2/image.py

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import numpy as np
22
try:
33
import cv2
4-
except:
5-
print(
6-
"import cv2 error, please install opencv-python: pip install opencv-python"
7-
)
4+
except ImportError:
5+
cv2 = None
6+
import os
7+
import tarfile
8+
import cPickle
89

910
__all__ = [
10-
"load_image", "resize_short", "to_chw", "center_crop", "random_crop",
11-
"left_right_flip", "simple_transform", "load_and_transform"
11+
"load_image_bytes", "load_image", "resize_short", "to_chw", "center_crop",
12+
"random_crop", "left_right_flip", "simple_transform", "load_and_transform",
13+
"batch_images_from_tar"
1214
]
1315
"""
1416
This file contains some common interfaces for image preprocess.
@@ -28,6 +30,90 @@
2830
"""
2931

3032

33+
def batch_images_from_tar(data_file,
34+
dataset_name,
35+
img2label,
36+
num_per_batch=1024):
37+
"""
38+
Read images from tar file and batch them into batch file.
39+
param data_file: path of image tar file
40+
type data_file: string
41+
param dataset_name: 'train','test' or 'valid'
42+
type dataset_name: string
43+
param img2label: a dic with image file name as key
44+
and image's label as value
45+
type img2label: dic
46+
param num_per_batch: image number per batch file
47+
type num_per_batch: int
48+
return: path of list file containing paths of batch file
49+
rtype: string
50+
"""
51+
batch_dir = data_file + "_batch"
52+
out_path = "%s/%s" % (batch_dir, dataset_name)
53+
meta_file = "%s/%s.txt" % (batch_dir, dataset_name)
54+
55+
if os.path.exists(out_path):
56+
return meta_file
57+
else:
58+
os.makedirs(out_path)
59+
60+
tf = tarfile.open(data_file)
61+
mems = tf.getmembers()
62+
data = []
63+
labels = []
64+
file_id = 0
65+
for mem in mems:
66+
if mem.name in img2label:
67+
data.append(tf.extractfile(mem).read())
68+
labels.append(img2label[mem.name])
69+
if len(data) == num_per_batch:
70+
output = {}
71+
output['label'] = labels
72+
output['data'] = data
73+
cPickle.dump(
74+
output,
75+
open('%s/batch_%d' % (out_path, file_id), 'w'),
76+
protocol=cPickle.HIGHEST_PROTOCOL)
77+
file_id += 1
78+
data = []
79+
labels = []
80+
if len(data) > 0:
81+
output = {}
82+
output['label'] = labels
83+
output['data'] = data
84+
cPickle.dump(
85+
output,
86+
open('%s/batch_%d' % (out_path, file_id), 'w'),
87+
protocol=cPickle.HIGHEST_PROTOCOL)
88+
89+
with open(meta_file, 'a') as meta:
90+
for file in os.listdir(out_path):
91+
meta.write(os.path.abspath("%s/%s" % (out_path, file)) + "\n")
92+
return meta_file
93+
94+
95+
def load_image_bytes(bytes, is_color=True):
96+
"""
97+
Load an color or gray image from bytes array.
98+
99+
Example usage:
100+
101+
.. code-block:: python
102+
with open('cat.jpg') as f:
103+
im = load_image_bytes(f.read())
104+
105+
:param bytes: the input image bytes array.
106+
:type file: str
107+
:param is_color: If set is_color True, it will load and
108+
return a color image. Otherwise, it will
109+
load and return a gray image.
110+
"""
111+
flag = 1 if is_color else 0
112+
file_bytes = np.asarray(bytearray(bytes), dtype=np.uint8)
113+
img = cv2.imdecode(file_bytes, flag)
114+
return img
115+
116+
31117
def load_image(file, is_color=True):
32118
"""
33119
Load an color or gray image from the file path.

0 commit comments

Comments
 (0)