Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit fbb27cf

Browse files
author
Taylor Robie
authored
Add fp16 support to official ResNet. (tensorflow#3687)
* Add fp16 support to resnet.
* address PR comments
* add dtype checking to model definition
* delint
* more PR comments
* few more tweaks
* update resnet checkpoints
1 parent 6741cfc commit fbb27cf

File tree

9 files changed

+430
-173
lines changed

9 files changed

+430
-173
lines changed

official/resnet/README.md

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ You can download 190 MB pre-trained versions of ResNet-50 achieving 76.3% and 75
5555

5656
Other versions and formats:
5757

58-
* [ResNet-v2-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnetv2_imagenet_checkpoint.tar.gz)
59-
* [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnetv2_imagenet_savedmodel.tar.gz)
60-
* [ResNet-v2-ImageNet Frozen Graph](http://download.tensorflow.org/models/official/resnetv2_imagenet_frozen_graph.pb)
61-
* [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnetv1_imagenet_checkpoint.tar.gz)
62-
* [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnetv1_imagenet_savedmodel.tar.gz)
63-
* [ResNet-v1-ImageNet Frozen Graph](http://download.tensorflow.org/models/official/resnetv1_imagenet_frozen_graph.pb)
58+
* [ResNet-v2-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v2_imagenet_checkpoint.tar.gz)
59+
* [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v2_imagenet_savedmodel.tar.gz)
60+
* [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v1_imagenet_checkpoint.tar.gz)
61+
* [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v1_imagenet_savedmodel.tar.gz)

official/resnet/cifar10_main.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ class Cifar10Model(resnet_model.Model):
145145
"""Model class with appropriate defaults for CIFAR-10 data."""
146146

147147
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
148-
version=resnet_model.DEFAULT_VERSION):
148+
version=resnet_model.DEFAULT_VERSION,
149+
dtype=resnet_model.DEFAULT_DTYPE):
149150
"""These are the parameters that work for CIFAR-10 data.
150151
151152
Args:
@@ -156,6 +157,7 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
156157
enables users to extend the same model to their own datasets.
157158
version: Integer representing which version of the ResNet network to use.
158159
See README for details. Valid values: [1, 2]
160+
dtype: The TensorFlow dtype to use for calculations.
159161
160162
Raises:
161163
ValueError: if invalid resnet_size is chosen
@@ -180,7 +182,9 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
180182
block_strides=[1, 2, 2],
181183
final_size=64,
182184
version=version,
183-
data_format=data_format)
185+
data_format=data_format,
186+
dtype=dtype
187+
)
184188

185189

186190
def cifar10_model_fn(features, labels, mode, params):
@@ -204,15 +208,22 @@ def cifar10_model_fn(features, labels, mode, params):
204208
def loss_filter_fn(_):
205209
return True
206210

207-
return resnet_run_loop.resnet_model_fn(features, labels, mode, Cifar10Model,
208-
resnet_size=params['resnet_size'],
209-
weight_decay=weight_decay,
210-
learning_rate_fn=learning_rate_fn,
211-
momentum=0.9,
212-
data_format=params['data_format'],
213-
version=params['version'],
214-
loss_filter_fn=loss_filter_fn,
215-
multi_gpu=params['multi_gpu'])
211+
return resnet_run_loop.resnet_model_fn(
212+
features=features,
213+
labels=labels,
214+
mode=mode,
215+
model_class=Cifar10Model,
216+
resnet_size=params['resnet_size'],
217+
weight_decay=weight_decay,
218+
learning_rate_fn=learning_rate_fn,
219+
momentum=0.9,
220+
data_format=params['data_format'],
221+
version=params['version'],
222+
loss_scale=params['loss_scale'],
223+
loss_filter_fn=loss_filter_fn,
224+
multi_gpu=params['multi_gpu'],
225+
dtype=params['dtype']
226+
)
216227

217228

218229
def main(argv):

official/resnet/cifar10_test.py

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -71,38 +71,61 @@ def test_dataset_input_fn(self):
7171
for pixel in row:
7272
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
7373

74+
def _cifar10_model_fn_helper(self, mode, version, dtype, multi_gpu=False):
75+
with tf.Graph().as_default() as g:
76+
input_fn = cifar10_main.get_synth_input_fn()
77+
dataset = input_fn(True, '', _BATCH_SIZE)
78+
iterator = dataset.make_one_shot_iterator()
79+
features, labels = iterator.get_next()
80+
spec = cifar10_main.cifar10_model_fn(
81+
features, labels, mode, {
82+
'dtype': dtype,
83+
'resnet_size': 32,
84+
'data_format': 'channels_last',
85+
'batch_size': _BATCH_SIZE,
86+
'version': version,
87+
'loss_scale': 128 if dtype == tf.float16 else 1,
88+
'multi_gpu': multi_gpu
89+
})
90+
91+
predictions = spec.predictions
92+
self.assertAllEqual(predictions['probabilities'].shape,
93+
(_BATCH_SIZE, 10))
94+
self.assertEqual(predictions['probabilities'].dtype, tf.float32)
95+
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
96+
self.assertEqual(predictions['classes'].dtype, tf.int64)
97+
98+
if mode != tf.estimator.ModeKeys.PREDICT:
99+
loss = spec.loss
100+
self.assertAllEqual(loss.shape, ())
101+
self.assertEqual(loss.dtype, tf.float32)
102+
103+
if mode == tf.estimator.ModeKeys.EVAL:
104+
eval_metric_ops = spec.eval_metric_ops
105+
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
106+
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
107+
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
108+
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
109+
110+
for v in tf.trainable_variables():
111+
self.assertEqual(v.dtype.base_dtype, tf.float32)
112+
113+
tensors_to_check = ('initial_conv:0', 'block_layer1:0', 'block_layer2:0',
114+
'block_layer3:0', 'final_reduce_mean:0',
115+
'final_dense:0')
116+
117+
for tensor_name in tensors_to_check:
118+
tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
119+
self.assertEqual(tensor.dtype, dtype,
120+
'Tensor {} has dtype {}, while dtype {} was '
121+
'expected'.format(tensor, tensor.dtype,
122+
dtype))
123+
74124
def cifar10_model_fn_helper(self, mode, version, multi_gpu=False):
75-
input_fn = cifar10_main.get_synth_input_fn()
76-
dataset = input_fn(True, '', _BATCH_SIZE)
77-
iterator = dataset.make_one_shot_iterator()
78-
features, labels = iterator.get_next()
79-
spec = cifar10_main.cifar10_model_fn(
80-
features, labels, mode, {
81-
'resnet_size': 32,
82-
'data_format': 'channels_last',
83-
'batch_size': _BATCH_SIZE,
84-
'version': version,
85-
'multi_gpu': multi_gpu
86-
})
87-
88-
predictions = spec.predictions
89-
self.assertAllEqual(predictions['probabilities'].shape,
90-
(_BATCH_SIZE, 10))
91-
self.assertEqual(predictions['probabilities'].dtype, tf.float32)
92-
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
93-
self.assertEqual(predictions['classes'].dtype, tf.int64)
94-
95-
if mode != tf.estimator.ModeKeys.PREDICT:
96-
loss = spec.loss
97-
self.assertAllEqual(loss.shape, ())
98-
self.assertEqual(loss.dtype, tf.float32)
99-
100-
if mode == tf.estimator.ModeKeys.EVAL:
101-
eval_metric_ops = spec.eval_metric_ops
102-
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
103-
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
104-
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
105-
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
125+
self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
126+
multi_gpu=multi_gpu)
127+
self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
128+
multi_gpu=multi_gpu)
106129

107130
def test_cifar10_model_fn_train_mode_v1(self):
108131
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
@@ -130,19 +153,22 @@ def test_cifar10_model_fn_predict_mode_v1(self):
130153
def test_cifar10_model_fn_predict_mode_v2(self):
131154
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2)
132155

133-
def test_cifar10model_shape(self):
156+
def _test_cifar10model_shape(self, version):
134157
batch_size = 135
135158
num_classes = 246
136159

137-
for version in (1, 2):
138-
model = cifar10_main.Cifar10Model(
139-
32, data_format='channels_last', num_classes=num_classes,
140-
version=version)
141-
fake_input = tf.random_uniform(
142-
[batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
143-
output = model(fake_input, training=True)
160+
model = cifar10_main.Cifar10Model(32, data_format='channels_last',
161+
num_classes=num_classes, version=version)
162+
fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
163+
output = model(fake_input, training=True)
164+
165+
self.assertAllEqual(output.shape, (batch_size, num_classes))
166+
167+
def test_cifar10model_shape_v1(self):
168+
self._test_cifar10model_shape(version=1)
144169

145-
self.assertAllEqual(output.shape, (batch_size, num_classes))
170+
def test_cifar10model_shape_v2(self):
171+
self._test_cifar10model_shape(version=2)
146172

147173
def test_cifar10_end_to_end_synthetic_v1(self):
148174
integration.run_synthetic(

official/resnet/imagenet_main.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,8 @@ class ImagenetModel(resnet_model.Model):
203203
"""Model class with appropriate defaults for Imagenet data."""
204204

205205
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
206-
version=resnet_model.DEFAULT_VERSION):
206+
version=resnet_model.DEFAULT_VERSION,
207+
dtype=resnet_model.DEFAULT_DTYPE):
207208
"""These are the parameters that work for Imagenet data.
208209
209210
Args:
@@ -214,6 +215,7 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
214215
enables users to extend the same model to their own datasets.
215216
version: Integer representing which version of the ResNet network to use.
216217
See README for details. Valid values: [1, 2]
218+
dtype: The TensorFlow dtype to use for calculations.
217219
"""
218220

219221
# For bigger models, we want to use "bottleneck" layers
@@ -239,7 +241,9 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
239241
block_strides=[1, 2, 2, 2],
240242
final_size=final_size,
241243
version=version,
242-
data_format=data_format)
244+
data_format=data_format,
245+
dtype=dtype
246+
)
243247

244248

245249
def _get_block_sizes(resnet_size):
@@ -283,15 +287,22 @@ def imagenet_model_fn(features, labels, mode, params):
283287
num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
284288
decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])
285289

286-
return resnet_run_loop.resnet_model_fn(features, labels, mode, ImagenetModel,
287-
resnet_size=params['resnet_size'],
288-
weight_decay=1e-4,
289-
learning_rate_fn=learning_rate_fn,
290-
momentum=0.9,
291-
data_format=params['data_format'],
292-
version=params['version'],
293-
loss_filter_fn=None,
294-
multi_gpu=params['multi_gpu'])
290+
return resnet_run_loop.resnet_model_fn(
291+
features=features,
292+
labels=labels,
293+
mode=mode,
294+
model_class=ImagenetModel,
295+
resnet_size=params['resnet_size'],
296+
weight_decay=1e-4,
297+
learning_rate_fn=learning_rate_fn,
298+
momentum=0.9,
299+
data_format=params['data_format'],
300+
version=params['version'],
301+
loss_scale=params['loss_scale'],
302+
loss_filter_fn=None,
303+
multi_gpu=params['multi_gpu'],
304+
dtype=params['dtype']
305+
)
295306

296307

297308
def main(argv):

0 commit comments

Comments
 (0)