@@ -138,7 +138,8 @@ def input_fn(is_training, data_dir, batch_size, *args, **kwargs):  # pylint: dis
 # Functions for running training/eval/validation loops for the model.
 ################################################################################
 def learning_rate_with_decay(
-    batch_size, batch_denom, num_images, boundary_epochs, decay_rates):
+    batch_size, batch_denom, num_images, boundary_epochs, decay_rates,
+    base_lr=0.1, warmup=False):
   """Get a learning rate that decays step-wise as training progresses.

   Args:
@@ -152,13 +153,14 @@ def learning_rate_with_decay(
     decay_rates: list of floats representing the decay rates to be used
       for scaling the learning rate. It should have one more element
       than `boundary_epochs`, and all elements should have the same type.
-
+    base_lr: Initial learning rate scaled based on batch_denom.
+    warmup: Run a 5 epoch warmup to the initial lr.
   Returns:
     Returns a function that takes a single argument - the number of batches
     trained so far (global_step)- and returns the learning rate to be used
     for training the next batch.
   """
-  initial_learning_rate = 0.1 * batch_size / batch_denom
+  initial_learning_rate = base_lr * batch_size / batch_denom
   batches_per_epoch = num_images / batch_size

   # Reduce the learning rate at certain epochs.
@@ -168,8 +170,15 @@ def learning_rate_with_decay(
   vals = [initial_learning_rate * decay for decay in decay_rates]

   def learning_rate_fn(global_step):
-    global_step = tf.cast(global_step, tf.int32)
-    return tf.train.piecewise_constant(global_step, boundaries, vals)
+    """Builds scaled learning rate function with 5 epoch warm up."""
+    lr = tf.train.piecewise_constant(global_step, boundaries, vals)
+    if warmup:
+      warmup_steps = int(batches_per_epoch * 5)
+      warmup_lr = (
+          initial_learning_rate * tf.cast(global_step, tf.float32) / tf.cast(
+              warmup_steps, tf.float32))
+      return tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: lr)
+    return lr

   return learning_rate_fn
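
Taken together, this hunk turns the schedule into a piecewise-constant decay with an optional linear warmup over the first 5 epochs. The pure-Python sketch below mirrors that math outside the TF graph to show the values the schedule yields; every hyperparameter in it (dataset size, batch size, boundary epochs, decay rates) is an illustrative assumption, not a value taken from this diff.

# Pure-Python mirror of the schedule built above. All hyperparameter values
# here are assumptions for illustration (ImageNet-like sizes), not part of
# the diff.
num_images = 1281167                    # assumed training-set size
batch_size, batch_denom = 256, 256
base_lr, warmup = 0.1, True

initial_learning_rate = base_lr * batch_size / batch_denom  # 0.1
batches_per_epoch = num_images / batch_size                 # ~5004.6
boundary_epochs = [30, 60, 80]                              # assumed boundaries
decay_rates = [1, 0.1, 0.01, 0.001]                         # one more than boundaries
boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
vals = [initial_learning_rate * decay for decay in decay_rates]

def learning_rate_fn(global_step):
  """Same math as the TF version: linear warmup, then step-wise decay."""
  warmup_steps = int(batches_per_epoch * 5)
  if warmup and global_step < warmup_steps:
    # Ramps linearly from 0 at step 0 to initial_learning_rate at warmup_steps.
    return initial_learning_rate * global_step / warmup_steps
  # tf.train.piecewise_constant semantics: vals[0] while step <= boundaries[0],
  # vals[i + 1] between consecutive boundaries, vals[-1] afterwards.
  for boundary, val in zip(boundaries, vals):
    if global_step <= boundary:
      return val
  return vals[-1]

print(learning_rate_fn(0))                             # 0.0, start of warmup
print(learning_rate_fn(int(2.5 * batches_per_epoch)))  # ~0.05, halfway through warmup
print(learning_rate_fn(int(40 * batches_per_epoch)))   # 0.01, after the first decay

Note that with warmup enabled the very first step trains at learning rate 0 and the ramp reaches `initial_learning_rate` exactly at `warmup_steps`, mirroring the `tf.cond` branch in the hunk above.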
@@ -499,12 +508,3 @@ def define_resnet_flags(resnet_size_choices=None):
     flags.DEFINE_string(**choice_kwargs)
   else:
     flags.DEFINE_enum(enum_values=resnet_size_choices, **choice_kwargs)
-
-  # The current implementation of ResNet v1 is numerically unstable when run
-  # with fp16 and will produce NaN errors soon after training begins.
-  msg = ('ResNet version 1 is not currently supported with fp16. '
-         'Please use version 2 instead.')
-  @flags.multi_flags_validator(['dtype', 'resnet_version'], message=msg)
-  def _forbid_v1_fp16(flag_values):  # pylint: disable=unused-variable
-    return (flags_core.DTYPE_MAP[flag_values['dtype']][0] != tf.float16 or
-            flag_values['resnet_version'] != '1')