Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0b74bfd

Browse files
ptrblckmcarilli
authored andcommitted
Disable tests for mixed opt_levels, add bitwise accurate test of parameters (NVIDIA#520)
* increase atol for Half-Float comparison to 1.5e-4 * disable tests for different opt_levels * reset atol * add bitwise accurate comparison
1 parent 03421e8 commit 0b74bfd

1 file changed

Lines changed: 16 additions & 51 deletions

File tree

tests/L0/run_amp/test_checkpointing.py

Lines changed: 16 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,18 @@ def train_step(self, model, optimizer, data, loss_ids):
5757
optimizer.step()
5858
return output
5959

60-
def compare_models(self, modelA, modelB):
60+
def compare_models(self, modelA, modelB, test_setup=''):
6161
state_dictA = modelA.state_dict()
6262
state_dictB = modelB.state_dict()
6363
self.assertEqual(len(state_dictA), len(state_dictB),
64-
'state_dicts have different lengths')
64+
'state_dicts have different lengths' + test_setup)
6565
for key in state_dictA:
6666
paramA = state_dictA[key]
6767
paramB = state_dictB[key]
68-
self.assertTrue(torch.allclose(paramA.float(), paramB.float(), rtol=0, atol=1e-4),
69-
msg='Parameters in state_dicts not equal.' +
70-
'key: {}\nparam: {}\nrestored: {}\ndiff: {}'.format(
71-
key, paramA, paramB, paramA - paramB))
68+
self.assertTrue((paramA==paramB).all(),
69+
msg='Parameters in state_dices not equal.' +
70+
'key: {}\nparam: {}\nrestored: {}\ndiff: {} for {}'.format(
71+
key, paramA, paramB, paramA - paramB, test_setup))
7272

7373
def test_restoring(self):
7474
nb_epochs = 10
@@ -77,11 +77,11 @@ def test_restoring(self):
7777
for res_opt_level in self.test_opt_levels:
7878
for amp_before_load in [True, False]:
7979
for num_losses in range(1, 3):
80-
# print('#' * 75 + '\n' + \
81-
# f'opt_level {opt_level}\n' + \
82-
# f'restore_opt_level {res_opt_level}\n' + \
83-
# f'amp_before_load {amp_before_load}\n' + \
84-
# f'num_losses {num_losses}\n')
80+
test_setup = ('#' * 75 + '\n' + \
81+
f'opt_level {opt_level}\n' + \
82+
f'restore_opt_level {res_opt_level}\n' + \
83+
f'amp_before_load {amp_before_load}\n' + \
84+
f'num_losses {num_losses}\n')
8585

8686
self.seed()
8787

@@ -154,47 +154,12 @@ def test_restoring(self):
154154
range(num_losses, num_losses*2))
155155
self.assertTrue(
156156
torch.allclose(output.float(), restore_output.float()),
157-
'Output of reference and restored models differ')
158-
self.compare_models(model, restore_model)
157+
'Output of reference and restored models differ for ' + test_setup)
158+
self.compare_models(model, restore_model, test_setup)
159159
# if opt_level != res_opt_level
160160
else:
161-
# Only check state_dict
162-
checkpoint = {
163-
'model': model.state_dict(),
164-
'optimizer': optimizer.state_dict(),
165-
'amp': amp.state_dict()
166-
}
167-
# Check state_dict for FP32 tensors
168-
self.check_state_dict_fp32(checkpoint['model'])
169-
170-
# Restore model
171-
restore_model = MyModel().to('cuda')
172-
restore_optimizer = optim.SGD(
173-
restore_model.parameters(),
174-
lr=self.initial_lr)
175-
176-
if amp_before_load:
177-
restore_model, restore_optimizer = amp.initialize(
178-
restore_model,
179-
restore_optimizer,
180-
opt_level=res_opt_level,
181-
num_losses=num_losses,
182-
verbosity=0)
183-
184-
restore_model.load_state_dict(checkpoint['model'])
185-
restore_optimizer.load_state_dict(checkpoint['optimizer'])
186-
# FIXME: We cannot test the amp.state_dict in the same script
187-
# amp.load_state_dict(checkpoint['amp'])
188-
189-
if not amp_before_load:
190-
restore_model, restore_optimizer = amp.initialize(
191-
restore_model,
192-
restore_optimizer,
193-
opt_level=res_opt_level,
194-
num_losses=num_losses,
195-
verbosity=0)
196-
197-
self.compare_models(model, restore_model)
161+
# skip tests for different opt_levels
162+
continue
198163

199164
def test_loss_scale_decrease(self):
200165
num_losses = 3
@@ -207,7 +172,7 @@ def test_loss_scale_decrease(self):
207172
model = MyModel().to('cuda')
208173

209174
optimizer = optim.SGD(model.parameters(),
210-
lr=1e-3)#self.initial_lr)
175+
lr=self.initial_lr)
211176

212177
model, optimizer = amp.initialize(
213178
model, optimizer, opt_level=opt_level, num_losses=num_losses,

0 commit comments

Comments
 (0)