@@ -121,17 +121,16 @@ def __init__(self,
121121 print ("FP16_Optimizer processing param group {}:" .format (i ))
122122 fp16_params_this_group = []
123123 fp32_params_this_group = []
124- master_params_this_group = []
125124 fp32_from_fp16_params_this_group = []
126- for param in param_group ['params' ]:
125+ for i , param in enumerate ( param_group ['params' ]) :
127126 if param .requires_grad :
128127 if param .type () == 'torch.cuda.HalfTensor' :
129128 print ("FP16_Optimizer received torch.cuda.HalfTensor with {}"
130129 .format (param .size ()))
131130 fp16_params_this_group .append (param )
132131 master_param = param .detach ().clone ().float ()
133132 master_param .requires_grad = True
134- master_params_this_group . append ( master_param )
133+ param_group [ 'params' ][ i ] = master_param
135134 fp32_from_fp16_params_this_group .append (master_param )
136135 # Reset existing state dict key to the new master param.
137136 # We still need to recast per-param state tensors, if any, to FP32.
@@ -141,14 +140,12 @@ def __init__(self,
141140 print ("FP16_Optimizer received torch.cuda.FloatTensor with {}"
142141 .format (param .size ()))
143142 fp32_params_this_group .append (param )
144- master_params_this_group . append ( param )
143+ param_group [ 'params' ][ i ] = param
145144 else :
146145 raise TypeError ("Wrapped parameters must be either "
147146 "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
148147 "Received {}" .format (param .type ()))
149148
150- param_group ['params' ] = master_params_this_group
151-
152149 self .fp16_groups .append (fp16_params_this_group )
153150 self .fp32_from_fp16_groups .append (fp32_from_fp16_params_this_group )
154151 self .fp32_from_fp32_groups .append (fp32_params_this_group )
0 commit comments