Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 76dbcb5

Browse files
committed
Move tf.cast for fp16 to input pipeline.
1 parent 5856878 commit 76dbcb5

File tree

3 files changed

+26
-14
lines changed

3 files changed

+26
-14
lines changed

official/resnet/cifar10_main.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def get_filenames(is_training, data_dir):
6666
return [os.path.join(data_dir, 'test_batch.bin')]
6767

6868

69-
def parse_record(raw_record, is_training):
69+
def parse_record(raw_record, is_training, dtype):
7070
"""Parse CIFAR-10 image and label from a raw record."""
7171
# Convert bytes to a vector of uint8 that is record_bytes long.
7272
record_vector = tf.decode_raw(raw_record, tf.uint8)
@@ -85,6 +85,7 @@ def parse_record(raw_record, is_training):
8585
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
8686

8787
image = preprocess_image(image, is_training)
88+
image = tf.cast(image, dtype)
8889

8990
return image, label
9091

@@ -107,15 +108,17 @@ def preprocess_image(image, is_training):
107108
return image
108109

109110

110-
def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
111-
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
111+
def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
112+
dtype=tf.float32):
113+
"""Input function which provides batches for train or eval.
112114
113115
Args:
114116
is_training: A boolean denoting whether the input is for training.
115117
data_dir: The directory containing the input data.
116118
batch_size: The number of samples per batch.
117119
num_epochs: The number of epochs to repeat the dataset.
118120
num_gpus: The number of gpus used for training.
121+
dtype: Data type to use for images/features
119122
120123
Returns:
121124
A dataset that can be used for iteration.
@@ -131,7 +134,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
131134
parse_record_fn=parse_record,
132135
num_epochs=num_epochs,
133136
num_gpus=num_gpus,
134-
examples_per_epoch=_NUM_IMAGES['train'] if is_training else None
137+
examples_per_epoch=_NUM_IMAGES['train'] if is_training else None,
138+
dtype=dtype
135139
)
136140

137141

official/resnet/imagenet_main.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def _parse_example_proto(example_serialized):
129129
return features['image/encoded'], label, bbox
130130

131131

132-
def parse_record(raw_record, is_training):
132+
def parse_record(raw_record, is_training, dtype):
133133
"""Parses a record containing a training example of an image.
134134
135135
The input record is parsed into a label and image, and the image is passed
@@ -139,6 +139,7 @@ def parse_record(raw_record, is_training):
139139
raw_record: scalar Tensor tf.string containing a serialized
140140
Example protocol buffer.
141141
is_training: A boolean denoting whether the input is for training.
142+
dtype: data type to use for images/features.
142143
143144
Returns:
144145
Tuple with processed image tensor and one-hot-encoded label tensor.
@@ -152,11 +153,13 @@ def parse_record(raw_record, is_training):
152153
output_width=_DEFAULT_IMAGE_SIZE,
153154
num_channels=_NUM_CHANNELS,
154155
is_training=is_training)
156+
image = tf.cast(image, dtype)
155157

156158
return image, label
157159

158160

159-
def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
161+
def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
162+
dtype=tf.float32):
160163
"""Input function which provides batches for train or eval.
161164
162165
Args:
@@ -165,6 +168,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
165168
batch_size: The number of samples per batch.
166169
num_epochs: The number of epochs to repeat the dataset.
167170
num_gpus: The number of gpus used for training.
171+
dtype: Data type to use for images/features
168172
169173
Returns:
170174
A dataset that can be used for iteration.
@@ -192,7 +196,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
192196
parse_record_fn=parse_record,
193197
num_epochs=num_epochs,
194198
num_gpus=num_gpus,
195-
examples_per_epoch=_NUM_IMAGES['train'] if is_training else None
199+
examples_per_epoch=_NUM_IMAGES['train'] if is_training else None,
200+
dtype=dtype
196201
)
197202

198203

official/resnet/resnet_run_loop.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
################################################################################
4646
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
4747
parse_record_fn, num_epochs=1, num_gpus=None,
48-
examples_per_epoch=None):
48+
examples_per_epoch=None, dtype=tf.float32):
4949
"""Given a Dataset with raw records, return an iterator over the records.
5050
5151
Args:
@@ -60,6 +60,7 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
6060
num_epochs: The number of epochs to repeat the dataset.
6161
num_gpus: The number of gpus used for training.
6262
examples_per_epoch: The number of examples in an epoch.
63+
dtype: Data type to use for images/features.
6364
6465
Returns:
6566
Dataset of (image, label) pairs ready for iteration.
@@ -92,7 +93,7 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
9293
# batch_size is almost always much greater than the number of CPU cores.
9394
dataset = dataset.apply(
9495
tf.contrib.data.map_and_batch(
95-
lambda value: parse_record_fn(value, is_training),
96+
lambda value: parse_record_fn(value, is_training, dtype),
9697
batch_size=batch_size,
9798
num_parallel_batches=1,
9899
drop_remainder=False))
@@ -248,8 +249,8 @@ def resnet_model_fn(features, labels, mode, model_class,
248249

249250
# Generate a summary node for the images
250251
tf.summary.image('images', features, max_outputs=6)
251-
# TODO(tobyboyd): Add cast as part of input pipeline on cpu and remove.
252-
features = tf.cast(features, dtype)
252+
# Checks that features/images have same data type being used for calculations.
253+
assert features.dtype == dtype
253254

254255
model = model_class(resnet_size, data_format, resnet_version=resnet_version,
255256
dtype=dtype)
@@ -454,14 +455,16 @@ def input_fn_train(num_epochs):
454455
batch_size=distribution_utils.per_device_batch_size(
455456
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
456457
num_epochs=num_epochs,
457-
num_gpus=flags_core.get_num_gpus(flags_obj))
458+
num_gpus=flags_core.get_num_gpus(flags_obj),
459+
dtype=flags_core.get_tf_dtype(flags_obj))
458460

459461
def input_fn_eval():
460462
return input_function(
461463
is_training=False, data_dir=flags_obj.data_dir,
462464
batch_size=distribution_utils.per_device_batch_size(
463465
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
464-
num_epochs=1)
466+
num_epochs=1,
467+
dtype=flags_core.get_tf_dtype(flags_obj))
465468

466469
if flags_obj.eval_only or not flags_obj.train_epochs:
467470
# If --eval_only is set, perform a single loop with zero train epochs.
@@ -533,7 +536,7 @@ def define_resnet_flags(resnet_size_choices=None):
533536
'If not None initialize all the network except the final layer with '
534537
'these values'))
535538
flags.DEFINE_boolean(
536-
name="eval_only", default=False,
539+
name='eval_only', default=False,
537540
help=flags_core.help_wrap('Skip training and only perform evaluation on '
538541
'the latest checkpoint.'))
539542

0 commit comments

Comments
 (0)