Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ed8236f

Browse files
Fix for unscale usage in fp16_utils.FP16_Optimizer
1 parent d137b80 commit ed8236f

8 files changed

Lines changed: 52 additions & 30 deletions

File tree

apex/amp/_initialize.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import torch
22
from torch._six import container_abcs, string_classes
33
import functools
4-
from apex.fp16_utils import convert_network
54
from ._amp_state import _amp_state
65
from .scaler import LossScaler
6+
from apex.fp16_utils import convert_network
77
from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
88
from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
99
from ..optimizers import FusedAdam
10+
from ..parallel import DistributedDataParallel as apex_DDP
1011

1112

1213
def to_type(dtype, t):
@@ -71,7 +72,7 @@ def check_optimizers(optimizers):
7172
bad_optim_type = None
7273
if isinstance(optim, FP16_Optimizer_general):
7374
bad_optim_type = "apex.fp16_utils.FP16_Optimizer"
74-
if isinstance(model, FP16_Optimizer_for_fused):
75+
if isinstance(optim, FP16_Optimizer_for_fused):
7576
bad_optim_type = "apex.optimizers.FP16_Optimizer"
7677
if bad_optim_type is not None:
7778
raise RuntimeError("An incoming optimizer is an instance of {}. ".format(optim_type) +
@@ -81,7 +82,7 @@ def check_optimizers(optimizers):
8182
"soon). You should not manually wrap your optimizer in either \n"
8283
"apex.fp16_utils.FP16_Optimizer or apex.optimizers.FP16_Optimizer. \n"
8384
"amp.initialize will take care of that for you (if necessary) based \n"
84-
"on the specified opt_level (and optional overridden properties)."
85+
"on the specified opt_level (and optional overridden properties).")
8586

8687

8788
def _initialize(models, optimizers, properties):
@@ -141,9 +142,11 @@ def new_fwd(*args, **kwargs):
141142
if isinstance(optimizer, FusedAdam):
142143
optimizers[i] = wrap_fused_adam(optimizer, properties)
143144
if properties.loss_scale == "dynamic":
144-
optimizers[i] = FP16_Optimizer_general(optimizers[i], dynamic_loss_scale=True)
145+
optimizers[i] = FP16_Optimizer_general(optimizers[i],
146+
dynamic_loss_scale=True)
145147
else:
146-
optimizers[i] = FP16_Optimizer(optimizers[i], static_loss_scale=properties.loss_scale)
148+
optimizers[i] = FP16_Optimizer_general(optimizers[i],
149+
static_loss_scale=properties.loss_scale)
147150
else:
148151
for optimizer in optimizers:
149152
optimizer.loss_scaler = LossScaler(properties.loss_scale)

apex/amp/frontend.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __call__(self, properties):
9191
properties.opt_level = "O2"
9292
properties.cast_model_type = torch.float16
9393
properties.patch_torch_functions = False
94-
properties.keep_batchnorm_fp32 = torch.float32
94+
properties.keep_batchnorm_fp32 = True
9595
properties.master_weights = True
9696
properties.loss_scale = "dynamic"
9797
properties.fused_optimizer = False
@@ -174,6 +174,7 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
174174
enable_ddp_interop=None):
175175
"""
176176
if not enabled:
177+
_amp_state.opt_properties = Properties()
177178
return models, optimizers
178179

179180
if opt_level not in opt_levels:
@@ -186,7 +187,7 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
186187
print("Defaults for this optimization level are:")
187188
print(_amp_state.opt_properties.options)
188189
for k, v in _amp_state.opt_properties.options.items():
189-
print("{:20} : {}".format(k, v))
190+
print("{:22} : {}".format(k, v))
190191

191192
print("Processing user overrides (additional kwargs that are not None)...")
192193
for k, v in kwargs.items():
@@ -197,7 +198,7 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
197198

198199
print("After processing overrides, optimization options are:")
199200
for k, v in _amp_state.opt_properties.options.items():
200-
print("{:20} : {}".format(k, v))
201+
print("{:22} : {}".format(k, v))
201202

202203
return _initialize(models, optimizers, _amp_state.opt_properties)
203204

@@ -228,7 +229,7 @@ def check_option_consistency(enabled=True,
228229
print("Selected optimization level {}", opt_levels[opt_level].brief)
229230
print("Defaults for this optimization level are:")
230231
for k, v in opt_properties.options:
231-
print("{:20} : {}".format(k, v))
232+
print("{:22} : {}".format(k, v))
232233

233234
print("Processing user overrides (additional kwargs that are not None)...")
234235
for k, v in kwargs:
@@ -239,4 +240,4 @@ def check_option_consistency(enabled=True,
239240

240241
print("After processing overrides, optimization options are:")
241242
for k, v in opt_properties.options:
242-
print("{:20} : {}".format(k, v))
243+
print("{:22} : {}".format(k, v))

apex/amp/handle.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@ def scale_loss(loss,
4545
if isinstance(optimizer, FP16_Optimizer):
4646
optimizer.update_master_grads()
4747
else:
48+
optimizer.loss_scaler.clear_overflow_state()
4849
optimizer.loss_scaler.unscale(
4950
iter_params(optimizer.param_groups),
5051
iter_params(optimizer.param_groups),
5152
loss_scale)
52-
# If overflow_check_on_cpu is False, should_skip will always be False.
53+
# In the future, once I have fused optimizers that enable sync-free dynamic loss scaling,
54+
# should_skip will always be False.
5355
should_skip = optimizer.loss_scaler.update_scale()
5456
if should_skip:
5557
optimizer_step = optimizer.step
@@ -101,6 +103,7 @@ def scale_loss(self, loss, optimizer):
101103
loss_scale = self._default_scaler.loss_scale()
102104
yield loss * loss_scale
103105

106+
self._default_scaler.clear_overflow_state()
104107
self._default_scaler.unscale(
105108
iter_params(optimizer.param_groups),
106109
iter_params(optimizer.param_groups),

apex/amp/opt.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def scale_loss(self, loss):
3737
loss_scale = self._cur_loss_scaler().loss_scale()
3838
yield loss * loss_scale
3939

40+
self._cur_loss_scaler().clear_overflow_state()
4041
self._cur_loss_scaler().unscale(
4142
iter_params(self._optimizer.param_groups),
4243
iter_params(self._optimizer.param_groups),

apex/amp/scaler.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,18 @@
55

66
# from apex_C import scale_check_overflow
77

8-
def scale_check_overflow_python(model_grad, scale, master_grad):
8+
def scale_check_overflow_python(model_grad, scale, master_grad, check_overflow=False):
99
# Exception handling for 18.04 compatibility
10-
try:
10+
if check_overflow:
1111
cpu_sum = float(model_grad.float().sum())
12-
except RuntimeError as instance:
13-
if "value cannot be converted" not in instance.args[0]:
14-
raise
15-
return True
16-
else:
1712
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
1813
return True
19-
if master_grad is not model_grad:
20-
master_grad.copy_(model_grad)
21-
if scale != 1.0:
22-
master_grad.mul_(scale)
23-
return False
14+
15+
if master_grad is not model_grad: # copy_ probably internally short-circuits this
16+
master_grad.copy_(model_grad)
17+
if scale != 1.0:
18+
master_grad.mul_(scale)
19+
return False
2420

2521
class LossScaler(object):
2622
warned_no_fused_kernel = False
@@ -73,12 +69,21 @@ def unscale_grads_python(self, model_grads, master_grads, scale):
7369
self._has_overflow = scale_check_overflow_python(
7470
model,
7571
1./scale,
76-
master)
72+
master,
73+
self.dynamic)
7774
if self._has_overflow and self.dynamic:
7875
break
7976

80-
def unscale(self, model_params, master_params, scale):
77+
def clear_overflow_state(self):
8178
self._has_overflow = False
79+
if self.has_fused_kernel:
80+
self._overflow_buf.zero_()
81+
82+
def unscale(self, model_params, master_params, scale):
83+
# torch.cuda.nvtx.range_push("unscale")
84+
if self._has_overflow:
85+
# torch.cuda.nvtx.range_pop()
86+
return
8287

8388
# Lots of defensive list processing going on here. Way more less efficient than
8489
# consuming the iterator directly. Need to examine Python overhead.
@@ -112,12 +117,12 @@ def unscale(self, model_params, master_params, scale):
112117
# Warning: setting this to True unconditionally allows the possibility of an escape
113118
# if never-before-seen non-fp32 grads are created in some later iteration.
114119
LossScaler.warned_unscaling_non_fp32_grad = True
115-
self._overflow_buf.zero_()
116120
# handle case of opt_level O1 and loss_scale 1.0. There's also some
117121
# special-cased yields in scale_loss to potentially short-circuit earlier.
118122
# TODO: Profile and find out if all the O(N) list processing in unscale()
119123
# is a bottleneck.
120124
if scale == 1.0 and all_same and not self.dynamic:
125+
# torch.cuda.nvtx.range_pop()
121126
return
122127
else:
123128
multi_tensor_applier(
@@ -128,12 +133,14 @@ def unscale(self, model_params, master_params, scale):
128133
else:
129134
self.unscale_grads_python(model_grads, master_grads, scale)
130135

131-
# Break into multiple param groups so unscale() can be called more that once before updating.
132-
def update_scale(self):
133136
# If the fused kernel is available, we only need one D2H memcopy and sync.
134137
if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
135138
self._has_overflow = self._overflow_buf.item()
136139

140+
# torch.cuda.nvtx.range_pop()
141+
142+
# Separate so unscale() can be called more that once before updating.
143+
def update_scale(self):
137144
if self._has_overflow and self.dynamic:
138145
should_skip = True
139146
self._loss_scale /= 2.

apex/fp16_utils/fp16_optimizer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,9 @@ def closure():
393393
if closure is not None:
394394
retval = self._step_with_closure(closure)
395395
else:
396+
# torch.cuda.nvtx.range_push("pytorch optimizer step")
396397
retval = self.optimizer.step()
398+
# torch.cuda.nvtx.range_pop()
397399

398400
self._master_params_to_model_params()
399401

@@ -502,6 +504,7 @@ def backward(self, loss, update_master_grads=True, retain_graph=False):
502504
self.update_master_grads()
503505

504506
def update_master_grads(self):
507+
# torch.cuda.nvtx.range_push("update_master_grads")
505508
"""
506509
Copy the ``.grad`` attribute from stored references to fp16 parameters to
507510
the ``.grad`` attribute of the fp32 master parameters that are directly
@@ -514,6 +517,7 @@ def update_master_grads(self):
514517
# self._model_grads_to_master_grads()
515518
# self._downscale_master()
516519
# Use the one-shot multi-tensor apply kernel
520+
self.loss_scaler.clear_overflow_state()
517521
if len(self.all_fp16_params) > 0:
518522
# print("Model grads before")
519523
# print([param.grad.data for param in self.all_fp16_params])
@@ -534,6 +538,7 @@ def update_master_grads(self):
534538
# print([param.grad.data for param in self.all_fp32_from_fp32_params])
535539
# quit()
536540
self.overflow = self.loss_scaler.update_scale()
541+
# torch.cuda.nvtx.range_pop()
537542

538543

539544
def inspect_master_grad_data(self):

examples/imagenet/main_amp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def fast_collate(batch):
9595
cudnn.benchmark = False
9696
cudnn.deterministic = True
9797
torch.manual_seed(args.local_rank)
98+
torch.set_printoptions(precision=10)
9899

99100
# Initialize Amp
100101
amp_handle = amp.init(enabled=args.fp16)
@@ -337,7 +338,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
337338
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
338339
'Speed {3:.3f} ({4:.3f})\t'
339340
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
340-
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
341+
'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
341342
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
342343
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
343344
epoch, i, len(train_loader),

examples/imagenet/main_fp16_optimizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def fast_collate(batch):
9999
cudnn.benchmark = False
100100
cudnn.deterministic = True
101101
torch.manual_seed(args.local_rank)
102+
torch.set_printoptions(precision=10)
102103

103104
def main():
104105
global best_prec1, args
@@ -344,7 +345,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
344345
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
345346
'Speed {3:.3f} ({4:.3f})\t'
346347
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
347-
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
348+
'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
348349
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
349350
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
350351
epoch, i, len(train_loader),

0 commit comments

Comments
 (0)