forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathflowers.py
More file actions
230 lines (205 loc) · 7.8 KB
/
Copy pathflowers.py
File metadata and controls
230 lines (205 loc) · 7.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module will download dataset from
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
and parse the train/test sets into paddle reader creators.
This set contains images of flowers belonging to 102 different categories.
The images were acquired by searching the web and taking pictures. There are a
minimum of 40 images for each category.
The database was used in:
Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
number of classes. Proceedings of the Indian Conference on Computer Vision,
Graphics and Image Processing (2008)
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
"""
from __future__ import print_function
import itertools
import functools
from .common import download
import tarfile
import scipy.io as scio
from paddle.dataset.image import *
from paddle.reader import *
from paddle import compat as cpt
import os
import numpy as np
from multiprocessing import cpu_count
import six
from six.moves import cPickle as pickle
# Reader creators exported by this module.
__all__ = ['train', 'test', 'valid']
# Download locations of the image archive, the label file and the file that
# describes how the images are split into subsets.
DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz'
LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
SETID_URL = 'http://paddlemodels.bj.bcebos.com/flowers/setid.mat'
# MD5 digests handed to download() together with the matching URL
# (presumably used to verify the downloaded files -- confirm in common.py).
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
# The official readme uses 'tstid' as the key of the test split and 'trnid'
# as the key of the training split. Because the 'tstid' split contains more
# images than 'trnid', the two are deliberately swapped here: 'tstid' is
# used for training and 'trnid' for testing.
TRAIN_FLAG = 'tstid'
TEST_FLAG = 'trnid'
VALID_FLAG = 'valid'
def default_mapper(is_train, sample):
    '''
    Convert one raw sample into the form fed to the model input layer.

    The image bytes are decoded with ``load_image_bytes``, passed through
    ``simple_transform`` (called with 256, 224, ``is_train`` and a
    three-value ``mean``), flattened and cast to float32. The label is
    returned untouched.

    :param is_train: forwarded to ``simple_transform``
    :type is_train: bool
    :param sample: pair of (image bytes, label)
    :type sample: tuple
    :return: pair of (flattened float32 image, label)
    :rtype: tuple
    '''
    raw_bytes, label = sample
    decoded = load_image_bytes(raw_bytes)
    transformed = simple_transform(
        decoded, 256, 224, is_train, mean=[103.94, 116.78, 123.68])
    flat = transformed.flatten()
    return flat.astype('float32'), label
# default_mapper with is_train bound to True; default mapper of train().
train_mapper = functools.partial(default_mapper, True)
# default_mapper with is_train bound to False; default mapper of test()
# and valid().
test_mapper = functools.partial(default_mapper, False)
def reader_creator(data_file,
                   label_file,
                   setid_file,
                   dataset_name,
                   mapper,
                   buffered_size=1024,
                   use_xmap=True,
                   cycle=False):
    '''
    Build a reader over one split of the flowers data set.

    1. the images belonging to the split are packed from the tar file into
       batch files by ``batch_images_from_tar``, which returns the path of
       a text file listing one batch file per line
    2. the returned reader unpickles those batch files one by one and
       yields ``(sample, label)`` pairs, where the label is the stored
       label minus one, each pair being passed through ``mapper``

    :param data_file: downloaded data file
    :type data_file: string
    :param label_file: downloaded label file
    :type label_file: string
    :param setid_file: downloaded setid file containing information
                       about how to split dataset
    :type setid_file: string
    :param dataset_name: data set name (tstid|trnid|valid)
    :type dataset_name: string
    :param mapper: a function to map image bytes data to type
                   needed by model input layer
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images;
                          only forwarded when ``use_xmap`` is true
    :type buffered_size: int
    :param use_xmap: if true, wrap the reader with ``xmap_readers``,
                     otherwise with ``map_readers``
    :type use_xmap: bool
    :param cycle: whether to cycle through the dataset
    :type cycle: bool
    :return: data reader
    :rtype: callable
    '''
    all_labels = scio.loadmat(label_file)['labels'][0]
    split_ids = scio.loadmat(setid_file)[dataset_name][0]
    # The label array is indexed by image id minus one.
    img2label = {
        "jpg/image_%05d.jpg" % img_id: all_labels[img_id - 1]
        for img_id in split_ids
    }
    batch_list_path = batch_images_from_tar(data_file, dataset_name,
                                            img2label)

    def reader():
        keep_reading = True
        while keep_reading:
            with open(batch_list_path, 'r') as list_fp:
                for line in list_fp:
                    batch_path = line.strip()
                    with open(batch_path, 'rb') as batch_fp:
                        if six.PY2:
                            batch = pickle.load(batch_fp)
                        else:
                            batch = pickle.load(batch_fp, encoding='bytes')
                    if six.PY3:
                        # presumably turns the bytes keys produced by
                        # encoding='bytes' back into text -- confirm in
                        # paddle.compat
                        batch = cpt.to_text(batch)
                    pairs = six.moves.zip(batch['data'], batch['label'])
                    for sample, label in pairs:
                        yield sample, int(label) - 1
            # A single pass unless the caller asked for endless cycling.
            keep_reading = cycle

    if not use_xmap:
        return map_readers(mapper, reader)
    return xmap_readers(mapper, reader, min(4, cpu_count()), buffered_size)
def train(mapper=train_mapper, buffered_size=1024, use_xmap=True, cycle=False):
    '''
    Create flowers training set reader.

    The three data files are downloaded and handed to ``reader_creator``
    together with ``TRAIN_FLAG``. Each sample of the returned reader is the
    result of ``mapper`` applied to an (image bytes, label) pair, where the
    label is the stored label minus one. With the default mapper the image
    is processed by ``simple_transform`` (called with 256 and 224),
    flattened and cast to float32.

    :param mapper: a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :param use_xmap: forwarded to ``reader_creator``
    :type use_xmap: bool
    :param cycle: whether to cycle through the dataset
    :type cycle: bool
    :return: train data reader
    :rtype: callable
    '''
    data_file = download(DATA_URL, 'flowers', DATA_MD5)
    label_file = download(LABEL_URL, 'flowers', LABEL_MD5)
    setid_file = download(SETID_URL, 'flowers', SETID_MD5)
    return reader_creator(
        data_file=data_file,
        label_file=label_file,
        setid_file=setid_file,
        dataset_name=TRAIN_FLAG,
        mapper=mapper,
        buffered_size=buffered_size,
        use_xmap=use_xmap,
        cycle=cycle)
def test(mapper=test_mapper, buffered_size=1024, use_xmap=True, cycle=False):
    '''
    Create flowers test set reader.

    The three data files are downloaded and handed to ``reader_creator``
    together with ``TEST_FLAG``. Each sample of the returned reader is the
    result of ``mapper`` applied to an (image bytes, label) pair, where the
    label is the stored label minus one. With the default mapper the image
    is processed by ``simple_transform`` (called with 256 and 224),
    flattened and cast to float32.

    :param mapper: a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :param use_xmap: forwarded to ``reader_creator``
    :type use_xmap: bool
    :param cycle: whether to cycle through the dataset
    :type cycle: bool
    :return: test data reader
    :rtype: callable
    '''
    data_file = download(DATA_URL, 'flowers', DATA_MD5)
    label_file = download(LABEL_URL, 'flowers', LABEL_MD5)
    setid_file = download(SETID_URL, 'flowers', SETID_MD5)
    return reader_creator(
        data_file=data_file,
        label_file=label_file,
        setid_file=setid_file,
        dataset_name=TEST_FLAG,
        mapper=mapper,
        buffered_size=buffered_size,
        use_xmap=use_xmap,
        cycle=cycle)
def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True, cycle=False):
    '''
    Create flowers validation set reader.

    The three data files are downloaded and handed to ``reader_creator``
    together with ``VALID_FLAG``. Each sample of the returned reader is the
    result of ``mapper`` applied to an (image bytes, label) pair, where the
    label is the stored label minus one. With the default mapper the image
    is processed by ``simple_transform`` (called with 256 and 224),
    flattened and cast to float32.

    :param mapper: a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :param use_xmap: forwarded to ``reader_creator``
    :type use_xmap: bool
    :param cycle: whether to cycle through the dataset; defaults to False,
                  which keeps the single-pass behaviour
    :type cycle: bool
    :return: validation data reader
    :rtype: callable
    '''
    return reader_creator(
        download(DATA_URL, 'flowers', DATA_MD5),
        download(LABEL_URL, 'flowers', LABEL_MD5),
        download(SETID_URL, 'flowers', SETID_MD5),
        VALID_FLAG,
        mapper,
        buffered_size,
        use_xmap,
        cycle=cycle)
def fetch():
    '''
    Download the image archive, the label file and the setid file, in
    that order, by calling ``download`` for each of them.
    '''
    resources = (
        (DATA_URL, DATA_MD5),
        (LABEL_URL, LABEL_MD5),
        (SETID_URL, SETID_MD5),
    )
    for url, md5 in resources:
        download(url, 'flowers', md5)