@@ -57,18 +57,18 @@ def train_step(self, model, optimizer, data, loss_ids):
5757 optimizer .step ()
5858 return output
5959
60- def compare_models (self , modelA , modelB ):
60+ def compare_models (self , modelA , modelB , test_setup = '' ):
6161 state_dictA = modelA .state_dict ()
6262 state_dictB = modelB .state_dict ()
6363 self .assertEqual (len (state_dictA ), len (state_dictB ),
64- 'state_dicts have different lengths' )
64+ 'state_dicts have different lengths' + test_setup )
6565 for key in state_dictA :
6666 paramA = state_dictA [key ]
6767 paramB = state_dictB [key ]
68- self .assertTrue (torch . allclose (paramA . float (), paramB . float (), rtol = 0 , atol = 1e-4 ),
69- msg = 'Parameters in state_dicts not equal.' +
70- 'key: {}\n param: {}\n restored: {}\n diff: {}' .format (
71- key , paramA , paramB , paramA - paramB ))
68+ self .assertTrue ((paramA == paramB ). all ( ),
69+ msg = 'Parameters in state_dices not equal.' +
70+ 'key: {}\n param: {}\n restored: {}\n diff: {} for {}' .format (
71+ key , paramA , paramB , paramA - paramB , test_setup ))
7272
7373 def test_restoring (self ):
7474 nb_epochs = 10
@@ -77,11 +77,11 @@ def test_restoring(self):
7777 for res_opt_level in self .test_opt_levels :
7878 for amp_before_load in [True , False ]:
7979 for num_losses in range (1 , 3 ):
80- # print ('#' * 75 + '\n' + \
81- # f'opt_level {opt_level}\n' + \
82- # f'restore_opt_level {res_opt_level}\n' + \
83- # f'amp_before_load {amp_before_load}\n' + \
84- # f'num_losses {num_losses}\n')
80+ test_setup = ('#' * 75 + '\n ' + \
81+ f'opt_level { opt_level } \n ' + \
82+ f'restore_opt_level { res_opt_level } \n ' + \
83+ f'amp_before_load { amp_before_load } \n ' + \
84+ f'num_losses { num_losses } \n ' )
8585
8686 self .seed ()
8787
@@ -154,47 +154,12 @@ def test_restoring(self):
154154 range (num_losses , num_losses * 2 ))
155155 self .assertTrue (
156156 torch .allclose (output .float (), restore_output .float ()),
157- 'Output of reference and restored models differ' )
158- self .compare_models (model , restore_model )
157+ 'Output of reference and restored models differ for ' + test_setup )
158+ self .compare_models (model , restore_model , test_setup )
159159 # if opt_level != res_opt_level
160160 else :
161- # Only check state_dict
162- checkpoint = {
163- 'model' : model .state_dict (),
164- 'optimizer' : optimizer .state_dict (),
165- 'amp' : amp .state_dict ()
166- }
167- # Check state_dict for FP32 tensors
168- self .check_state_dict_fp32 (checkpoint ['model' ])
169-
170- # Restore model
171- restore_model = MyModel ().to ('cuda' )
172- restore_optimizer = optim .SGD (
173- restore_model .parameters (),
174- lr = self .initial_lr )
175-
176- if amp_before_load :
177- restore_model , restore_optimizer = amp .initialize (
178- restore_model ,
179- restore_optimizer ,
180- opt_level = res_opt_level ,
181- num_losses = num_losses ,
182- verbosity = 0 )
183-
184- restore_model .load_state_dict (checkpoint ['model' ])
185- restore_optimizer .load_state_dict (checkpoint ['optimizer' ])
186- # FIXME: We cannot test the amp.state_dict in the same script
187- # amp.load_state_dict(checkpoint['amp'])
188-
189- if not amp_before_load :
190- restore_model , restore_optimizer = amp .initialize (
191- restore_model ,
192- restore_optimizer ,
193- opt_level = res_opt_level ,
194- num_losses = num_losses ,
195- verbosity = 0 )
196-
197- self .compare_models (model , restore_model )
161+ # skip tests for different opt_levels
162+ continue
198163
199164 def test_loss_scale_decrease (self ):
200165 num_losses = 3
@@ -207,7 +172,7 @@ def test_loss_scale_decrease(self):
207172 model = MyModel ().to ('cuda' )
208173
209174 optimizer = optim .SGD (model .parameters (),
210- lr = 1e-3 ) # self.initial_lr)
175+ lr = self .initial_lr )
211176
212177 model , optimizer = amp .initialize (
213178 model , optimizer , opt_level = opt_level , num_losses = num_losses ,
0 commit comments