File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -131,10 +131,10 @@ def step(self, closure=None):
131131 group ['exp_avg_sq' ][1 ] = torch .cuda .FloatTensor (len (g_32 )).contiguous ().fill_ (0 )
132132 else : # init with first step norm, so first blend have no effect
133133 if group ['norm_type' ] == 0 :
134- v_16 = [torch .max (torch .abs (g )).item () for g in g_16 ]
134+ v_16 = [torch .max (torch .abs (g . to ( torch . float32 ) )).item () for g in g_16 ]
135135 v_32 = [torch .max (torch .abs (g )).item () for g in g_32 ]
136136 elif group ['norm_type' ] == 2 :
137- v_16 = [torch .sum (torch .pow (g , 2 )).sqrt ().item () for g in g_16 ]
137+ v_16 = [torch .sum (torch .pow (g . to ( torch . float32 ) , 2 )).sqrt ().item () for g in g_16 ]
138138 v_32 = [torch .sum (torch .pow (g , 2 )).sqrt ().item () for g in g_32 ]
139139 else :
140140 raise RuntimeError ('FusedNovoGrad only support l2/inf norm now.' )
You can’t perform that action at this time.
0 commit comments