Commit 18d05ad

Author: Taylor Robie

Restore ResNet Distribution Strategies (tensorflow#4134)

* Revert 823da31. This restores distribution strategies for resnet. This commit is not a direct revert due to significant merge conflict resolution.
* fix flags test
* npc is no longer used in resnet

1 parent 51a2b44 · commit 18d05ad

10 files changed: +201 -260 lines changed


official/mnist/mnist.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def create_model(data_format):
 
 
 def define_mnist_flags():
-  flags_core.define_base()
+  flags_core.define_base(multi_gpu=True, num_gpu=False)
   flags_core.define_image()
   flags.adopt_module_key_flags(flags_core)
   flags_core.set_defaults(data_dir='/tmp/mnist_data',
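The boolean keyword arguments in the `define_base` call above toggle which base command-line flags get registered: MNIST appears to keep the legacy `--multi_gpu` flag while opting out of the `--num_gpus` flag used by the ResNet distribution-strategy path. A minimal sketch of that toggle pattern, purely illustrative and not the actual `flags_core` implementation:

```python
from absl import flags


def define_base(multi_gpu=False, num_gpu=True):
  """Illustrative sketch: each boolean gates whether a flag gets defined."""
  if multi_gpu:
    flags.DEFINE_bool(
        name='multi_gpu', default=False,
        help='If set, run across all available GPUs.')
  if num_gpu:
    flags.DEFINE_integer(
        name='num_gpus', default=1,
        help='How many GPUs to use with the DistributionStrategies API.')
```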

official/resnet/README.md

Lines changed: 10 additions & 0 deletions
@@ -59,3 +59,13 @@ Other versions and formats:
 * [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v2_imagenet_savedmodel.tar.gz)
 * [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v1_imagenet_checkpoint.tar.gz)
 * [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v1_imagenet_savedmodel.tar.gz)
+
+## Compute Devices
+Training is accomplished using the DistributionStrategies API. (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/distribute/README.md)
+
+The appropriate distribution strategy is chosen based on the `--num_gpus` flag. By default this flag is one if TensorFlow is compiled with CUDA, and zero otherwise.
+
+num_gpus:
++ 0: Use OneDeviceStrategy and train on CPU.
++ 1: Use OneDeviceStrategy and train on GPU.
++ 2+: Use MirroredStrategy (data parallelism) to distribute a batch between devices.
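As a companion to the README text above, here is a minimal sketch of how that `num_gpus`-to-strategy mapping can be expressed against the `tf.contrib.distribute` API of this era. The helper name and the RunConfig wiring are illustrative, not necessarily the exact code added elsewhere in this commit:

```python
import tensorflow as tf  # TF 1.x, where DistributionStrategies live in tf.contrib.distribute


def select_distribution_strategy(num_gpus):
  """Maps --num_gpus to a strategy, following the table in the README."""
  if num_gpus == 0:
    # CPU-only training.
    return tf.contrib.distribute.OneDeviceStrategy('/device:CPU:0')
  if num_gpus == 1:
    # Single-GPU training still goes through a strategy for uniformity.
    return tf.contrib.distribute.OneDeviceStrategy('/device:GPU:0')
  # 2+ GPUs: mirror the model and split each global batch across devices.
  return tf.contrib.distribute.MirroredStrategy(num_gpus=num_gpus)


# The strategy would then be handed to the Estimator via RunConfig, e.g.:
# run_config = tf.estimator.RunConfig(
#     train_distribute=select_distribution_strategy(num_gpus))
```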

official/resnet/cifar10_main.py

Lines changed: 3 additions & 13 deletions
@@ -105,34 +105,25 @@ def preprocess_image(image, is_training):
   return image
 
 
-def input_fn(is_training, data_dir, batch_size, num_epochs=1,
-             num_parallel_calls=1, multi_gpu=False):
+def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
 
   Args:
     is_training: A boolean denoting whether the input is for training.
     data_dir: The directory containing the input data.
     batch_size: The number of samples per batch.
     num_epochs: The number of epochs to repeat the dataset.
-    num_parallel_calls: The number of records that are processed in parallel.
-      This can be optimized per data set but for generally homogeneous data
-      sets, should be approximately the number of available CPU cores.
-    multi_gpu: Whether this is run multi-GPU. Note that this is only required
-      currently to handle the batch leftovers, and can be removed
-      when that is handled directly by Estimator.
 
   Returns:
     A dataset that can be used for iteration.
   """
   filenames = get_filenames(is_training, data_dir)
   dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
 
-  num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
-
   return resnet_run_loop.process_record_dataset(
       dataset, is_training, batch_size, _NUM_IMAGES['train'],
-      parse_record, num_epochs, num_parallel_calls,
-      examples_per_epoch=num_images, multi_gpu=multi_gpu)
+      parse_record, num_epochs,
+  )
 
 
 def get_synth_input_fn():

@@ -221,7 +212,6 @@ def loss_filter_fn(_):
       version=params['version'],
       loss_scale=params['loss_scale'],
       loss_filter_fn=loss_filter_fn,
-      multi_gpu=params['multi_gpu'],
       dtype=params['dtype']
   )
 
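With `num_parallel_calls` and `multi_gpu` removed, callers of the CIFAR-10 input pipeline pass only the four remaining arguments. A hedged usage sketch (the data path and batch size are placeholder values, not taken from this commit):

```python
from official.resnet import cifar10_main

# Hypothetical call site for the simplified signature diffed above;
# '/tmp/cifar10_data' and 128 are placeholders.
dataset = cifar10_main.input_fn(
    is_training=True,
    data_dir='/tmp/cifar10_data',
    batch_size=128,
    num_epochs=1)
```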

official/resnet/cifar10_test.py

Lines changed: 44 additions & 68 deletions
@@ -76,87 +76,63 @@ def test_dataset_input_fn(self):
       for pixel in row:
         self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
 
-  def _cifar10_model_fn_helper(self, mode, version, dtype, multi_gpu=False):
-    with tf.Graph().as_default() as g:
-      input_fn = cifar10_main.get_synth_input_fn()
-      dataset = input_fn(True, '', _BATCH_SIZE)
-      iterator = dataset.make_one_shot_iterator()
-      features, labels = iterator.get_next()
-      spec = cifar10_main.cifar10_model_fn(
-          features, labels, mode, {
-              'dtype': dtype,
-              'resnet_size': 32,
-              'data_format': 'channels_last',
-              'batch_size': _BATCH_SIZE,
-              'version': version,
-              'loss_scale': 128 if dtype == tf.float16 else 1,
-              'multi_gpu': multi_gpu
-          })
-
-      predictions = spec.predictions
-      self.assertAllEqual(predictions['probabilities'].shape,
-                          (_BATCH_SIZE, 10))
-      self.assertEqual(predictions['probabilities'].dtype, tf.float32)
-      self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
-      self.assertEqual(predictions['classes'].dtype, tf.int64)
-
-      if mode != tf.estimator.ModeKeys.PREDICT:
-        loss = spec.loss
-        self.assertAllEqual(loss.shape, ())
-        self.assertEqual(loss.dtype, tf.float32)
-
-      if mode == tf.estimator.ModeKeys.EVAL:
-        eval_metric_ops = spec.eval_metric_ops
-        self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
-        self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
-        self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
-        self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
-
-      for v in tf.trainable_variables():
-        self.assertEqual(v.dtype.base_dtype, tf.float32)
-
-      tensors_to_check = ('initial_conv:0', 'block_layer1:0', 'block_layer2:0',
-                          'block_layer3:0', 'final_reduce_mean:0',
-                          'final_dense:0')
-
-      for tensor_name in tensors_to_check:
-        tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
-        self.assertEqual(tensor.dtype, dtype,
-                         'Tensor {} has dtype {}, while dtype {} was '
-                         'expected'.format(tensor, tensor.dtype,
-                                           dtype))
-
-  def cifar10_model_fn_helper(self, mode, version, multi_gpu=False):
-    self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
-                                  multi_gpu=multi_gpu)
-    self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
-                                  multi_gpu=multi_gpu)
+  def cifar10_model_fn_helper(self, mode, version, dtype):
+    input_fn = cifar10_main.get_synth_input_fn()
+    dataset = input_fn(True, '', _BATCH_SIZE)
+    iterator = dataset.make_one_shot_iterator()
+    features, labels = iterator.get_next()
+    spec = cifar10_main.cifar10_model_fn(
+        features, labels, mode, {
+            'dtype': dtype,
+            'resnet_size': 32,
+            'data_format': 'channels_last',
+            'batch_size': _BATCH_SIZE,
+            'version': version,
+            'loss_scale': 128 if dtype == tf.float16 else 1,
+        })
+
+    predictions = spec.predictions
+    self.assertAllEqual(predictions['probabilities'].shape,
+                        (_BATCH_SIZE, 10))
+    self.assertEqual(predictions['probabilities'].dtype, tf.float32)
+    self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
+    self.assertEqual(predictions['classes'].dtype, tf.int64)
+
+    if mode != tf.estimator.ModeKeys.PREDICT:
+      loss = spec.loss
+      self.assertAllEqual(loss.shape, ())
+      self.assertEqual(loss.dtype, tf.float32)
+
+    if mode == tf.estimator.ModeKeys.EVAL:
+      eval_metric_ops = spec.eval_metric_ops
+      self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
+      self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
+      self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
+      self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
 
   def test_cifar10_model_fn_train_mode_v1(self):
-    self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
-
-  def test_cifar10_model_fn_trainmode__v2(self):
-    self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)
-
-  def test_cifar10_model_fn_train_mode_multi_gpu_v1(self):
     self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
-                                 multi_gpu=True)
+                                 dtype=tf.float32)
 
-  def test_cifar10_model_fn_train_mode_multi_gpu_v2(self):
+  def test_cifar10_model_fn_trainmode__v2(self):
     self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
-                                 multi_gpu=True)
+                                 dtype=tf.float32)
 
   def test_cifar10_model_fn_eval_mode_v1(self):
-    self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1)
+    self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1,
+                                 dtype=tf.float32)
 
   def test_cifar10_model_fn_eval_mode_v2(self):
-    self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2)
+    self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2,
+                                 dtype=tf.float32)
 
   def test_cifar10_model_fn_predict_mode_v1(self):
-    self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1)
+    self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1,
+                                 dtype=tf.float32)
 
   def test_cifar10_model_fn_predict_mode_v2(self):
-    self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2)
+    self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2,
+                                 dtype=tf.float32)
 
   def _test_cifar10model_shape(self, version):
     batch_size = 135

official/resnet/imagenet_main.py

Lines changed: 3 additions & 13 deletions
@@ -156,21 +156,14 @@ def parse_record(raw_record, is_training):
   return image, label
 
 
-def input_fn(is_training, data_dir, batch_size, num_epochs=1,
-             num_parallel_calls=1, multi_gpu=False):
+def input_fn(is_training, data_dir, batch_size, num_epochs=1):
   """Input function which provides batches for train or eval.
 
   Args:
     is_training: A boolean denoting whether the input is for training.
     data_dir: The directory containing the input data.
     batch_size: The number of samples per batch.
     num_epochs: The number of epochs to repeat the dataset.
-    num_parallel_calls: The number of records that are processed in parallel.
-      This can be optimized per data set but for generally homogeneous data
-      sets, should be approximately the number of available CPU cores.
-    multi_gpu: Whether this is run multi-GPU. Note that this is only required
-      currently to handle the batch leftovers, and can be removed
-      when that is handled directly by Estimator.
 
   Returns:
     A dataset that can be used for iteration.

@@ -182,15 +175,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
   # Shuffle the input files
   dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
 
-  num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
-
   # Convert to individual records
   dataset = dataset.flat_map(tf.data.TFRecordDataset)
 
   return resnet_run_loop.process_record_dataset(
       dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record,
-      num_epochs, num_parallel_calls, examples_per_epoch=num_images,
-      multi_gpu=multi_gpu)
+      num_epochs
+  )
 
 
 def get_synth_input_fn():

@@ -300,7 +291,6 @@ def imagenet_model_fn(features, labels, mode, params):
       version=params['version'],
      loss_scale=params['loss_scale'],
       loss_filter_fn=None,
-      multi_gpu=params['multi_gpu'],
       dtype=params['dtype']
   )
 

official/resnet/imagenet_test.py

Lines changed: 46 additions & 68 deletions
@@ -185,88 +185,66 @@ def test_tensor_shapes_resnet_200_with_gpu_v1(self):
   def test_tensor_shapes_resnet_200_with_gpu_v2(self):
     self.tensor_shapes_helper(200, version=2, with_gpu=True)
 
-  def _resnet_model_fn_helper(self, mode, version, dtype, multi_gpu):
+  def resnet_model_fn_helper(self, mode, version, dtype):
     """Tests that the EstimatorSpec is given the appropriate arguments."""
-    with tf.Graph().as_default() as g:
-      tf.train.create_global_step()
-
-      input_fn = imagenet_main.get_synth_input_fn()
-      dataset = input_fn(True, '', _BATCH_SIZE)
-      iterator = dataset.make_one_shot_iterator()
-      features, labels = iterator.get_next()
-      spec = imagenet_main.imagenet_model_fn(
-          features, labels, mode, {
-              'dtype': dtype,
-              'resnet_size': 50,
-              'data_format': 'channels_last',
-              'batch_size': _BATCH_SIZE,
-              'version': version,
-              'loss_scale': 128 if dtype == tf.float16 else 1,
-              'multi_gpu': multi_gpu,
-          })
-
-      predictions = spec.predictions
-      self.assertAllEqual(predictions['probabilities'].shape,
-                          (_BATCH_SIZE, _LABEL_CLASSES))
-      self.assertEqual(predictions['probabilities'].dtype, tf.float32)
-      self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
-      self.assertEqual(predictions['classes'].dtype, tf.int64)
-
-      if mode != tf.estimator.ModeKeys.PREDICT:
-        loss = spec.loss
-        self.assertAllEqual(loss.shape, ())
-        self.assertEqual(loss.dtype, tf.float32)
-
-      if mode == tf.estimator.ModeKeys.EVAL:
-        eval_metric_ops = spec.eval_metric_ops
-        self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
-        self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
-        self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
-        self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
-
-      tensors_to_check = ('initial_conv:0', 'initial_max_pool:0',
-                          'block_layer1:0', 'block_layer2:0',
-                          'block_layer3:0', 'block_layer4:0',
-                          'final_reduce_mean:0', 'final_dense:0')
-
-      for tensor_name in tensors_to_check:
-        tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
-        self.assertEqual(tensor.dtype, dtype,
-                         'Tensor {} has dtype {}, while dtype {} was '
-                         'expected'.format(tensor, tensor.dtype,
-                                           dtype))
-
-  def resnet_model_fn_helper(self, mode, version, multi_gpu=False):
-    self._resnet_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
-                                 multi_gpu=multi_gpu)
-    self._resnet_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
-                                 multi_gpu=multi_gpu)
+    tf.train.create_global_step()
+
+    input_fn = imagenet_main.get_synth_input_fn()
+    dataset = input_fn(True, '', _BATCH_SIZE)
+    iterator = dataset.make_one_shot_iterator()
+    features, labels = iterator.get_next()
+    spec = imagenet_main.imagenet_model_fn(
+        features, labels, mode, {
+            'dtype': dtype,
+            'resnet_size': 50,
+            'data_format': 'channels_last',
+            'batch_size': _BATCH_SIZE,
+            'version': version,
+            'loss_scale': 128 if dtype == tf.float16 else 1,
+        })
+
+    predictions = spec.predictions
+    self.assertAllEqual(predictions['probabilities'].shape,
+                        (_BATCH_SIZE, _LABEL_CLASSES))
+    self.assertEqual(predictions['probabilities'].dtype, tf.float32)
+    self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
+    self.assertEqual(predictions['classes'].dtype, tf.int64)
+
+    if mode != tf.estimator.ModeKeys.PREDICT:
+      loss = spec.loss
+      self.assertAllEqual(loss.shape, ())
+      self.assertEqual(loss.dtype, tf.float32)
+
+    if mode == tf.estimator.ModeKeys.EVAL:
+      eval_metric_ops = spec.eval_metric_ops
+      self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
+      self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
+      self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
+      self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
 
   def test_resnet_model_fn_train_mode_v1(self):
-    self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
-
-  def test_resnet_model_fn_train_mode_v2(self):
-    self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)
-
-  def test_resnet_model_fn_train_mode_multi_gpu_v1(self):
     self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
-                                multi_gpu=True)
+                                dtype=tf.float32)
 
-  def test_resnet_model_fn_train_mode_multi_gpu_v2(self):
+  def test_resnet_model_fn_train_mode_v2(self):
     self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
-                                multi_gpu=True)
+                                dtype=tf.float32)
 
   def test_resnet_model_fn_eval_mode_v1(self):
-    self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1)
+    self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1,
+                                dtype=tf.float32)
 
   def test_resnet_model_fn_eval_mode_v2(self):
-    self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2)
+    self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2,
+                                dtype=tf.float32)
 
   def test_resnet_model_fn_predict_mode_v1(self):
-    self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1)
+    self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1,
+                                dtype=tf.float32)
 
   def test_resnet_model_fn_predict_mode_v2(self):
-    self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2)
+    self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2,
+                                dtype=tf.float32)
 
   def _test_imagenetmodel_shape(self, version):
     batch_size = 135
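Since the rewritten helpers take `dtype` explicitly, an fp16 variant of any of these tests is a one-line call. A hypothetical example, not a test added by this commit, that would sit inside the existing test class in imagenet_test.py:

```python
  # Hypothetical test method (not part of this diff): exercises the fp16 path
  # by passing dtype=tf.float16 to the helper shown above.
  def test_resnet_model_fn_train_mode_fp16_v1(self):
    self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
                                dtype=tf.float16)
```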
