55
66# from apex_C import scale_check_overflow
77
8- def scale_check_overflow_python (model_grad , scale , master_grad ):
8+ def scale_check_overflow_python (model_grad , scale , master_grad , check_overflow = False ):
99 # Exception handling for 18.04 compatibility
10- try :
10+ if check_overflow :
1111 cpu_sum = float (model_grad .float ().sum ())
12- except RuntimeError as instance :
13- if "value cannot be converted" not in instance .args [0 ]:
14- raise
15- return True
16- else :
1712 if cpu_sum == float ('inf' ) or cpu_sum == - float ('inf' ) or cpu_sum != cpu_sum :
1813 return True
19- if master_grad is not model_grad :
20- master_grad .copy_ (model_grad )
21- if scale != 1.0 :
22- master_grad .mul_ (scale )
23- return False
14+
15+ if master_grad is not model_grad : # copy_ probably internally short-circuits this
16+ master_grad .copy_ (model_grad )
17+ if scale != 1.0 :
18+ master_grad .mul_ (scale )
19+ return False
2420
2521class LossScaler (object ):
2622 warned_no_fused_kernel = False
@@ -73,12 +69,21 @@ def unscale_grads_python(self, model_grads, master_grads, scale):
7369 self ._has_overflow = scale_check_overflow_python (
7470 model ,
7571 1. / scale ,
76- master )
72+ master ,
73+ self .dynamic )
7774 if self ._has_overflow and self .dynamic :
7875 break
7976
80- def unscale (self , model_params , master_params , scale ):
77+ def clear_overflow_state (self ):
8178 self ._has_overflow = False
79+ if self .has_fused_kernel :
80+ self ._overflow_buf .zero_ ()
81+
82+ def unscale (self , model_params , master_params , scale ):
83+ # torch.cuda.nvtx.range_push("unscale")
84+ if self ._has_overflow :
85+ # torch.cuda.nvtx.range_pop()
86+ return
8287
8388 # Lots of defensive list processing going on here. Far less efficient than
8489 # consuming the iterator directly. Need to examine Python overhead.
@@ -112,12 +117,12 @@ def unscale(self, model_params, master_params, scale):
112117 # Warning: setting this to True unconditionally allows the possibility of an escape
113118 # if never-before-seen non-fp32 grads are created in some later iteration.
114119 LossScaler .warned_unscaling_non_fp32_grad = True
115- self ._overflow_buf .zero_ ()
116120 # handle case of opt_level O1 and loss_scale 1.0. There's also some
117121 # special-cased yields in scale_loss to potentially short-circuit earlier.
118122 # TODO: Profile and find out if all the O(N) list processing in unscale()
119123 # is a bottleneck.
120124 if scale == 1.0 and all_same and not self .dynamic :
125+ # torch.cuda.nvtx.range_pop()
121126 return
122127 else :
123128 multi_tensor_applier (
@@ -128,12 +133,14 @@ def unscale(self, model_params, master_params, scale):
128133 else :
129134 self .unscale_grads_python (model_grads , master_grads , scale )
130135
131- # Break into multiple param groups so unscale() can be called more that once before updating.
132- def update_scale (self ):
133136 # If the fused kernel is available, we only need one D2H memcopy and sync.
134137 if LossScaler .has_fused_kernel and self .dynamic and not self ._has_overflow :
135138 self ._has_overflow = self ._overflow_buf .item ()
136139
140+ # torch.cuda.nvtx.range_pop()
141+
142+ # Separate so unscale() can be called more than once before updating.
143+ def update_scale (self ):
137144 if self ._has_overflow and self .dynamic :
138145 should_skip = True
139146 self ._loss_scale /= 2.
0 commit comments