
Commit fd1d178

guptapriya authored and Taylor Robie committed
Add minor performance improvements to resnet input pipeline (tensorflow#4340)
* Remove one-hot labels, add drop_remainder to batch, use parallel interleave in imagenet dataset.
* Minor lint fix.
* Don't try to read the files twice...
* Add explanation for cycle_length.
1 parent 419bc6e commit fd1d178

File tree: 4 files changed (+16, -14 lines)


official/resnet/cifar10_main.py

Lines changed: 0 additions & 1 deletion
@@ -73,7 +73,6 @@ def parse_record(raw_record, is_training):
   # The first byte represents the label, which we convert from uint8 to int32
   # and then to one-hot.
   label = tf.cast(record_vector[0], tf.int32)
-  label = tf.one_hot(label, _NUM_CLASSES)

   # The remaining bytes after the label represent the image, which we reshape
   # from [depth * height * width] to [depth, height, width].
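With the one-hot conversion removed, parse_record now yields a scalar int32 class id rather than a length-_NUM_CLASSES vector. A minimal sketch of the two label representations (TF 1.x API of this era; the 10-class depth matches CIFAR-10):

import tensorflow as tf  # TF 1.x

label = tf.constant(7, dtype=tf.int32)   # sparse label, shape ()
one_hot = tf.one_hot(label, depth=10)    # dense one-hot label, shape (10,)

with tf.Session() as sess:
    print(sess.run(label))    # 7
    print(sess.run(one_hot))  # [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]

Keeping labels sparse in the input pipeline avoids materializing a (batch_size, num_classes) tensor per batch; the matching switch to a sparse loss appears in resnet_run_loop.py below.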

official/resnet/cifar10_test.py

Lines changed: 2 additions & 2 deletions
@@ -64,13 +64,13 @@ def test_dataset_input_fn(self):
         lambda val: cifar10_main.parse_record(val, False))
     image, label = fake_dataset.make_one_shot_iterator().get_next()

-    self.assertAllEqual(label.shape, (10,))
+    self.assertAllEqual(label.shape, ())
     self.assertAllEqual(image.shape, (_HEIGHT, _WIDTH, _NUM_CHANNELS))

     with self.test_session() as sess:
       image, label = sess.run([image, label])

-    self.assertAllEqual(label, np.array([int(i == 7) for i in range(10)]))
+    self.assertEqual(label, 7)

     for row in image:
       for pixel in row:

official/resnet/imagenet_main.py

Lines changed: 8 additions & 5 deletions
@@ -39,7 +39,7 @@
 }

 _NUM_TRAIN_FILES = 1024
-_SHUFFLE_BUFFER = 1500
+_SHUFFLE_BUFFER = 10000

 DATASET_NAME = 'ImageNet'

@@ -152,8 +152,6 @@ def parse_record(raw_record, is_training):
       num_channels=_NUM_CHANNELS,
       is_training=is_training)

-  label = tf.one_hot(tf.reshape(label, shape=[]), _NUM_CLASSES)
-
   return image, label

@@ -176,8 +174,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
     # Shuffle the input files
     dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)

-  # Convert to individual records
-  dataset = dataset.flat_map(tf.data.TFRecordDataset)
+  # Convert to individual records.
+  # cycle_length = 10 means 10 files will be read and deserialized in parallel.
+  # This number is low enough to not cause too much contention on small systems
+  # but high enough to provide the benefits of parallelization. You may want
+  # to increase this number if you have a large number of CPU cores.
+  dataset = dataset.apply(tf.contrib.data.parallel_interleave(
+      tf.data.TFRecordDataset, cycle_length=10))

   return resnet_run_loop.process_record_dataset(
       dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record,
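For context, a standalone sketch of the read pattern this hunk introduces (TF 1.x contrib API; the file pattern below is a hypothetical stand-in for the 1024 ImageNet train shards):

import tensorflow as tf  # TF 1.x; parallel_interleave lives in tf.contrib.data

# Hypothetical path standing in for the sharded TFRecord training files.
filenames = tf.data.Dataset.list_files('/tmp/imagenet/train-*')
filenames = filenames.shuffle(buffer_size=1024)

# flat_map reads one file at a time; parallel_interleave opens cycle_length
# files at once and interleaves their records, overlapping I/O and
# deserialization across files.
records = filenames.apply(tf.contrib.data.parallel_interleave(
    tf.data.TFRecordDataset, cycle_length=10))

In later TensorFlow releases the same effect is available without contrib via dataset.interleave(tf.data.TFRecordDataset, cycle_length=10, num_parallel_calls=...).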

official/resnet/resnet_run_loop.py

Lines changed: 6 additions & 6 deletions
@@ -79,7 +79,8 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
       tf.contrib.data.map_and_batch(
           lambda value: parse_record_fn(value, is_training),
           batch_size=batch_size,
-          num_parallel_batches=1))
+          num_parallel_batches=1,
+          drop_remainder=True))

   # Operations between the final prefetch and the get_next call to the iterator
   # will happen synchronously during run time. We prefetch here again to
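drop_remainder=True discards a final short batch so that every batch has a fully static leading dimension, which some consumers of the pipeline (e.g. XLA compilation or certain distribution strategies) require. A toy sketch of the effect, using the same contrib API this hunk touches:

import tensorflow as tf  # TF 1.x

dataset = tf.data.Dataset.range(10)

# 10 elements at batch_size 4 would leave a ragged final batch of [8, 9];
# drop_remainder=True drops it, so the batch shape is statically (4,).
dataset = dataset.apply(tf.contrib.data.map_and_batch(
    lambda x: x, batch_size=4, num_parallel_batches=1, drop_remainder=True))

print(dataset.output_shapes)  # (4,) rather than (?,)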
@@ -111,7 +112,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
   """
   def input_fn(is_training, data_dir, batch_size, *args, **kwargs):  # pylint: disable=unused-argument
     images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
-    labels = tf.zeros((batch_size, num_classes), tf.int32)
+    labels = tf.zeros((batch_size), tf.int32)
     return tf.data.Dataset.from_tensors((images, labels)).repeat()

   return input_fn
@@ -227,8 +228,8 @@ def resnet_model_fn(features, labels, mode, model_class,
       })

   # Calculate loss, which includes softmax cross entropy and L2 regularization.
-  cross_entropy = tf.losses.softmax_cross_entropy(
-      logits=logits, onehot_labels=labels)
+  cross_entropy = tf.losses.sparse_softmax_cross_entropy(
+      logits=logits, labels=labels)

   # Create a tensor named cross_entropy for logging purposes.
   tf.identity(cross_entropy, name='cross_entropy')
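The loss change is the counterpart of the sparse labels above: tf.losses.sparse_softmax_cross_entropy consumes integer class ids directly and computes the same value softmax_cross_entropy would on their one-hot encoding, skipping the intermediate (batch, num_classes) tensor. A quick equivalence check (the 1001-class depth is an assumption matching ImageNet-with-background):

import tensorflow as tf  # TF 1.x

logits = tf.random_normal([8, 1001])                          # (batch, num_classes)
labels = tf.random_uniform([8], maxval=1001, dtype=tf.int32)  # (batch,)

sparse = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
dense = tf.losses.softmax_cross_entropy(
    onehot_labels=tf.one_hot(labels, 1001), logits=logits)

with tf.Session() as sess:
    a, b = sess.run([sparse, dense])  # same labels/logits within one run
    print(abs(a - b) < 1e-5)          # True, up to float tolerance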
@@ -282,8 +283,7 @@ def exclude_batch_norm(name):
     train_op = None

   if not tf.contrib.distribute.has_distribution_strategy():
-    accuracy = tf.metrics.accuracy(
-        tf.argmax(labels, axis=1), predictions['classes'])
+    accuracy = tf.metrics.accuracy(labels, predictions['classes'])
   else:
     # Metrics are currently not compatible with distribution strategies during
     # training. This does not affect the overall performance of the model.
