
Commit 9bf586d

tfboyd authored and Taylor Robie committed
Add 5 epoch warmup to resnet (tensorflow#5176)
* Add 5 epoch warmup
* get_lr with warm_up only for imagenet
* Add base_lr, remove fp16 unittest arg validation
* Remove validation check stopping v1 and FP16
1 parent 981c003 commit 9bf586d

4 files changed: +25 −29 lines changed

official/resnet/cifar10_test.py

Lines changed: 0 additions & 7 deletions
@@ -165,13 +165,6 @@ def test_cifar10_end_to_end_synthetic_v2(self):
         extra_flags=['-resnet_version', '2']
     )
 
-  def test_flag_restriction(self):
-    with self.assertRaises(SystemExit):
-      integration.run_synthetic(
-          main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
-          extra_flags=['-resnet_version', '1', "-dtype", "fp16"]
-      )
-
 
 if __name__ == '__main__':
   tf.test.main()

official/resnet/imagenet_main.py

Lines changed: 11 additions & 1 deletion
@@ -285,10 +285,20 @@ def _get_block_sizes(resnet_size):
 
 def imagenet_model_fn(features, labels, mode, params):
   """Our model_fn for ResNet to be used with our Estimator."""
+
+  # Warmup and higher lr may not be valid for fine tuning with small batches
+  # and smaller numbers of training images.
+  if params['fine_tune']:
+    warmup = False
+    base_lr = .1
+  else:
+    warmup = True
+    base_lr = .128
+
   learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
       batch_size=params['batch_size'], batch_denom=256,
       num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
-      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])
+      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4], warmup=warmup, base_lr=base_lr)
 
   return resnet_run_loop.resnet_model_fn(
       features=features,
official/resnet/imagenet_test.py

Lines changed: 0 additions & 7 deletions
@@ -304,13 +304,6 @@ def test_imagenet_end_to_end_synthetic_v2_huge(self):
         extra_flags=['-resnet_version', '2', '-resnet_size', '200']
     )
 
-  def test_flag_restriction(self):
-    with self.assertRaises(SystemExit):
-      integration.run_synthetic(
-          main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
-          extra_flags=['-resnet_version', '1', '-dtype', 'fp16']
-      )
-
 
 if __name__ == '__main__':
   tf.test.main()

official/resnet/resnet_run_loop.py

Lines changed: 14 additions & 14 deletions
@@ -138,7 +138,8 @@ def input_fn(is_training, data_dir, batch_size, *args, **kwargs): # pylint: dis
 # Functions for running training/eval/validation loops for the model.
 ################################################################################
 def learning_rate_with_decay(
-    batch_size, batch_denom, num_images, boundary_epochs, decay_rates):
+    batch_size, batch_denom, num_images, boundary_epochs, decay_rates,
+    base_lr=0.1, warmup=False):
   """Get a learning rate that decays step-wise as training progresses.
 
   Args:
@@ -152,13 +153,14 @@ def learning_rate_with_decay(
     decay_rates: list of floats representing the decay rates to be used
       for scaling the learning rate. It should have one more element
       than `boundary_epochs`, and all elements should have the same type.
-
+    base_lr: Initial learning rate scaled based on batch_denom.
+    warmup: Run a 5 epoch warmup to the initial lr.
   Returns:
     Returns a function that takes a single argument - the number of batches
     trained so far (global_step)- and returns the learning rate to be used
     for training the next batch.
   """
-  initial_learning_rate = 0.1 * batch_size / batch_denom
+  initial_learning_rate = base_lr * batch_size / batch_denom
   batches_per_epoch = num_images / batch_size
 
   # Reduce the learning rate at certain epochs.
@@ -168,8 +170,15 @@ def learning_rate_with_decay(
   vals = [initial_learning_rate * decay for decay in decay_rates]
 
   def learning_rate_fn(global_step):
-    global_step = tf.cast(global_step, tf.int32)
-    return tf.train.piecewise_constant(global_step, boundaries, vals)
+    """Builds scaled learning rate function with 5 epoch warm up."""
+    lr = tf.train.piecewise_constant(global_step, boundaries, vals)
+    if warmup:
+      warmup_steps = int(batches_per_epoch * 5)
+      warmup_lr = (
+          initial_learning_rate * tf.cast(global_step, tf.float32) / tf.cast(
+              warmup_steps, tf.float32))
+      return tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: lr)
+    return lr
 
   return learning_rate_fn
 
@@ -499,12 +508,3 @@ def define_resnet_flags(resnet_size_choices=None):
     flags.DEFINE_string(**choice_kwargs)
   else:
     flags.DEFINE_enum(enum_values=resnet_size_choices, **choice_kwargs)
-
-  # The current implementation of ResNet v1 is numerically unstable when run
-  # with fp16 and will produce NaN errors soon after training begins.
-  msg = ('ResNet version 1 is not currently supported with fp16. '
-         'Please use version 2 instead.')
-  @flags.multi_flags_validator(['dtype', 'resnet_version'], message=msg)
-  def _forbid_v1_fp16(flag_values):  # pylint: disable=unused-variable
-    return (flags_core.DTYPE_MAP[flag_values['dtype']][0] != tf.float16 or
-            flag_values['resnet_version'] != '1')
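
Note (not part of the commit): for readers who want to trace the schedule without TensorFlow, the following plain-Python sketch mirrors what learning_rate_fn above computes: a linear ramp from 0 to the initial rate over the first 5 epochs when warmup is enabled, then piecewise-constant decay at the boundary epochs. The name sketch_learning_rate and its defaults are illustrative, and num_images is only a placeholder for an ImageNet-sized training set.

# Plain-Python sketch (illustrative, not the commit's code) of the schedule
# built by learning_rate_with_decay: linear warmup over the first 5 epochs,
# followed by piecewise-constant decay at the boundary epochs.
def sketch_learning_rate(global_step, batch_size=256, batch_denom=256,
                         num_images=1281167,  # placeholder ImageNet-sized count
                         boundary_epochs=(30, 60, 80, 90),
                         decay_rates=(1, 0.1, 0.01, 0.001, 1e-4),
                         base_lr=0.128, warmup=True):
  initial_learning_rate = base_lr * batch_size / batch_denom
  batches_per_epoch = num_images / batch_size
  boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
  vals = [initial_learning_rate * decay for decay in decay_rates]
  if warmup:
    warmup_steps = int(batches_per_epoch * 5)
    if global_step < warmup_steps:
      # Linear ramp from 0 to the initial rate, as warmup_lr does above.
      return initial_learning_rate * global_step / warmup_steps
  # Mirror tf.train.piecewise_constant: vals[i] applies up through boundaries[i].
  for boundary, val in zip(boundaries, vals):
    if global_step <= boundary:
      return val
  return vals[-1]

# Example: the rate mid-warmup, after warmup, and after the final boundary.
for step in (0, 10000, 30000, 200000, 500000):
  print(step, sketch_learning_rate(step))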
