Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 481728d

Browse files
authored
Merge pull request tensorflow#5225 from tfboyd/resnet_synthetic_fix
ResNet synthetic data performance enhancement.
2 parents e0f6a39 + 967133c commit 481728d

File tree

5 files changed

+44
-23
lines changed

5 files changed

+44
-23
lines changed

official/resnet/cifar10_main.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
135135
)
136136

137137

138-
def get_synth_input_fn():
138+
def get_synth_input_fn(dtype):
139139
return resnet_run_loop.get_synth_input_fn(
140-
_HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)
140+
_HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES, dtype=dtype)
141141

142142

143143
###############################################################################
@@ -243,8 +243,9 @@ def run_cifar(flags_obj):
243243
Args:
244244
flags_obj: An object containing parsed flag values.
245245
"""
246-
input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
247-
or input_fn)
246+
input_function = (flags_obj.use_synthetic_data and
247+
get_synth_input_fn(flags_core.get_tf_dtype(flags_obj)) or
248+
input_fn)
248249
resnet_run_loop.resnet_main(
249250
flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
250251
shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])

official/resnet/cifar10_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ def test_dataset_input_fn(self):
7777
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
7878

7979
def cifar10_model_fn_helper(self, mode, resnet_version, dtype):
80-
input_fn = cifar10_main.get_synth_input_fn()
80+
input_fn = cifar10_main.get_synth_input_fn(dtype)
8181
dataset = input_fn(True, '', _BATCH_SIZE)
82-
iterator = dataset.make_one_shot_iterator()
82+
iterator = dataset.make_initializable_iterator()
8383
features, labels = iterator.get_next()
8484
spec = cifar10_main.cifar10_model_fn(
8585
features, labels, mode, {

official/resnet/imagenet_main.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
196196
)
197197

198198

199-
def get_synth_input_fn():
199+
def get_synth_input_fn(dtype):
200200
return resnet_run_loop.get_synth_input_fn(
201-
_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES)
201+
_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES,
202+
dtype=dtype)
202203

203204

204205
###############################################################################
@@ -331,8 +332,9 @@ def run_imagenet(flags_obj):
331332
Args:
332333
flags_obj: An object containing parsed flag values.
333334
"""
334-
input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
335-
or input_fn)
335+
input_function = (flags_obj.use_synthetic_data and
336+
get_synth_input_fn(flags_core.get_tf_dtype(flags_obj)) or
337+
input_fn)
336338

337339
resnet_run_loop.resnet_main(
338340
flags_obj, imagenet_model_fn, input_function, DATASET_NAME,

official/resnet/imagenet_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,9 @@ def resnet_model_fn_helper(self, mode, resnet_version, dtype):
191191
"""Tests that the EstimatorSpec is given the appropriate arguments."""
192192
tf.train.create_global_step()
193193

194-
input_fn = imagenet_main.get_synth_input_fn()
194+
input_fn = imagenet_main.get_synth_input_fn(dtype)
195195
dataset = input_fn(True, '', _BATCH_SIZE)
196-
iterator = dataset.make_one_shot_iterator()
196+
iterator = dataset.make_initializable_iterator()
197197
features, labels = iterator.get_next()
198198
spec = imagenet_main.imagenet_model_fn(
199199
features, labels, mode, {

official/resnet/resnet_run_loop.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -108,29 +108,47 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
108108
return dataset
109109

110110

111-
def get_synth_input_fn(height, width, num_channels, num_classes):
112-
"""Returns an input function that returns a dataset with zeroes.
111+
def get_synth_input_fn(height, width, num_channels, num_classes,
112+
dtype=tf.float32):
113+
"""Returns an input function that returns a dataset with random data.
113114
114-
This is useful in debugging input pipeline performance, as it removes all
115-
elements of file reading and image preprocessing.
115+
This input_fn returns a data set that iterates over a set of random data and
116+
bypasses all preprocessing, e.g. jpeg decode and copy. The host to device
117+
copy is still included. This used to find the upper throughput bound when
118+
tunning the full input pipeline.
116119
117120
Args:
118121
height: Integer height that will be used to create a fake image tensor.
119122
width: Integer width that will be used to create a fake image tensor.
120123
num_channels: Integer depth that will be used to create a fake image tensor.
121124
num_classes: Number of classes that should be represented in the fake labels
122125
tensor
126+
dtype: Data type for features/images.
123127
124128
Returns:
125129
An input_fn that can be used in place of a real one to return a dataset
126130
that can be used for iteration.
127131
"""
128-
def input_fn(is_training, data_dir, batch_size, *args, **kwargs): # pylint: disable=unused-argument
129-
return model_helpers.generate_synthetic_data(
130-
input_shape=tf.TensorShape([batch_size, height, width, num_channels]),
131-
input_dtype=tf.float32,
132-
label_shape=tf.TensorShape([batch_size]),
133-
label_dtype=tf.int32)
132+
# pylint: disable=unused-argument
133+
def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
134+
"""Returns dataset filled with random data."""
135+
# Synthetic input should be within [0, 255].
136+
inputs = tf.truncated_normal(
137+
[batch_size] + [height, width, num_channels],
138+
dtype=dtype,
139+
mean=127,
140+
stddev=60,
141+
name='synthetic_inputs')
142+
143+
labels = tf.random_uniform(
144+
[batch_size],
145+
minval=0,
146+
maxval=num_classes - 1,
147+
dtype=tf.int32,
148+
name='synthetic_labels')
149+
data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
150+
data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
151+
return data
134152

135153
return input_fn
136154

@@ -230,7 +248,7 @@ def resnet_model_fn(features, labels, mode, model_class,
230248

231249
# Generate a summary node for the images
232250
tf.summary.image('images', features, max_outputs=6)
233-
251+
# TODO(tobyboyd): Add cast as part of input pipeline on cpu and remove.
234252
features = tf.cast(features, dtype)
235253

236254
model = model_class(resnet_size, data_format, resnet_version=resnet_version,

0 commit comments

Comments
 (0)