Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit ae921de

Browse files
Fixing FP16_Optimizer handling of LBFGS
1 parent d695b68 commit ae921de

1 file changed

Lines changed: 3 additions & 6 deletions

File tree

apex/fp16_utils/fp16_optimizer.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -121,17 +121,16 @@ def __init__(self,
121121
print("FP16_Optimizer processing param group {}:".format(i))
122122
fp16_params_this_group = []
123123
fp32_params_this_group = []
124-
master_params_this_group = []
125124
fp32_from_fp16_params_this_group = []
126-
for param in param_group['params']:
125+
for i, param in enumerate(param_group['params']):
127126
if param.requires_grad:
128127
if param.type() == 'torch.cuda.HalfTensor':
129128
print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
130129
.format(param.size()))
131130
fp16_params_this_group.append(param)
132131
master_param = param.detach().clone().float()
133132
master_param.requires_grad = True
134-
master_params_this_group.append(master_param)
133+
param_group['params'][i] = master_param
135134
fp32_from_fp16_params_this_group.append(master_param)
136135
# Reset existing state dict key to the new master param.
137136
# We still need to recast per-param state tensors, if any, to FP32.
@@ -141,14 +140,12 @@ def __init__(self,
141140
print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
142141
.format(param.size()))
143142
fp32_params_this_group.append(param)
144-
master_params_this_group.append(param)
143+
param_group['params'][i] = param
145144
else:
146145
raise TypeError("Wrapped parameters must be either "
147146
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
148147
"Received {}".format(param.type()))
149148

150-
param_group['params'] = master_params_this_group
151-
152149
self.fp16_groups.append(fp16_params_this_group)
153150
self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
154151
self.fp32_from_fp32_groups.append(fp32_params_this_group)

0 commit comments

Comments
 (0)