5656parser .add_argument ('--dropout' , type = float , default = 0.5 , metavar = '' , help = 'dropout parameter' )
5757parser .add_argument ('--net-type' , type = str , default = 'resnet18' , metavar = '' , help = 'type of network' )
5858parser .add_argument ('--act' , type = str , default = 'relu' , metavar = '' , help = 'activation function (for both perturb and conv layers)' )
59+ parser .add_argument ('--pool_type' , type = str , default = 'max' , metavar = '' , help = 'pooling function (max or avg)' )
5960
6061# ======================== Training Settings =======================================
6162parser .add_argument ('--batch-size' , type = int , default = 64 , metavar = '' , help = 'batch size for training' )
@@ -90,14 +91,15 @@ def __init__(self, args):
9091 self .nmasks = args .nmasks
9192 self .unique_masks = args .unique_masks
9293 self .filter_size = args .filter_size
94+ self .first_filter_size = args .first_filter_size
9395 self .scale_noise = args .scale_noise
9496 self .noise_type = args .noise_type
9597 self .act = args .act
9698 self .use_act = args .use_act
9799 self .dropout = args .dropout
98- self .first_filter_size = args .first_filter_size
99100 self .train_masks = args .train_masks
100101 self .debug = args .debug
102+ self .pool_type = args .pool_type
101103
102104 if self .dataset_train_name .startswith ("CIFAR" ):
103105 self .input_size = 32
@@ -123,14 +125,16 @@ def __init__(self, args):
123125 unique_masks = self .unique_masks ,
124126 level = self .level ,
125127 filter_size = self .filter_size ,
128+ first_filter_size = self .first_filter_size ,
126129 act = self .act ,
127130 scale_noise = self .scale_noise ,
128131 noise_type = self .noise_type ,
129132 use_act = self .use_act ,
130133 dropout = self .dropout ,
131- first_filter_size = self .first_filter_size ,
132134 train_masks = self .train_masks ,
133- debug = self .debug
135+ pool_type = self .pool_type ,
136+ debug = self .debug ,
137+ input_size = self .input_size
134138 )
135139
136140 self .loss_fn = nn .CrossEntropyLoss ()
@@ -224,7 +228,9 @@ def test(self, dataloader):
224228
225229 return np .mean (losses ), np .mean (accuracies )
226230
231+ print ('\n \n ****** Creating {} model ******\n \n ' .format (args .net_type ))
227232setup = Model (args )
233+ print ('\n \n ****** Preparing {} dataset *******\n \n ' .format (args .dataset_train ))
228234dataloader = Dataloader (args , setup .input_size )
229235loader_train , loader_test = dataloader .create ()
230236
@@ -255,10 +261,6 @@ def test(self, dataloader):
255261 init_epoch += 1
256262
257263
258- print ('\n \n ****** Model Configuration ******\n \n ' )
259- for arg in vars (args ):
260- print (arg , getattr (args , arg ))
261-
262264print ('\n \n ****** Model Graph ******\n \n ' )
263265for arg in vars (model ):
264266 print (arg , getattr (model , arg ))
@@ -270,12 +272,19 @@ def test(self, dataloader):
270272
271273print ('\n \n Model: {}, {:.2f}M parameters\n \n ' .format (args .net_type , sum (p .numel () for p in model .parameters ()) / 1000000. ))
272274
273- if model .train_masks :
274- msg = 'also training noise masks values'
275+ print ('\n \n ****** Model Configuration ******\n \n ' )
276+ for arg in vars (args ):
277+ print (arg , getattr (args , arg ))
278+
279+ if args .net_type != 'resnet18' and args .net_type != 'noiseresnet18' and (args .first_filter_size == 0 or args .filter_size == 0 ):
280+ if args .train_masks :
281+ msg = '(also training noise masks values)'
282+ else :
283+ msg = '(noise masks are fixed)'
275284else :
276- msg = 'moise masks are fixed '
285+ msg = ''
277286
278- print ('\n \n Training Model {}\n \n ' .format (msg ))
287+ print ('\n \n Training {} model {}\n \n ' .format (args . net_type , msg ))
279288
280289for epoch in range (init_epoch , args .nepochs , 1 ):
281290
0 commit comments