Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 1046a42

Browse files
fixed first layer masks bug, added pool arg
1 parent 85b7241 commit 1046a42

File tree

2 files changed

+125
-94
lines changed

2 files changed

+125
-94
lines changed

main.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
parser.add_argument('--dropout', type=float, default=0.5, metavar='', help='dropout parameter')
5757
parser.add_argument('--net-type', type=str, default='resnet18', metavar='', help='type of network')
5858
parser.add_argument('--act', type=str, default='relu', metavar='', help='activation function (for both perturb and conv layers)')
59+
parser.add_argument('--pool_type', type=str, default='max', metavar='', help='pooling function (max or avg)')
5960

6061
# ======================== Training Settings =======================================
6162
parser.add_argument('--batch-size', type=int, default=64, metavar='', help='batch size for training')
@@ -90,14 +91,15 @@ def __init__(self, args):
9091
self.nmasks = args.nmasks
9192
self.unique_masks = args.unique_masks
9293
self.filter_size = args.filter_size
94+
self.first_filter_size = args.first_filter_size
9395
self.scale_noise = args.scale_noise
9496
self.noise_type = args.noise_type
9597
self.act = args.act
9698
self.use_act = args.use_act
9799
self.dropout = args.dropout
98-
self.first_filter_size = args.first_filter_size
99100
self.train_masks = args.train_masks
100101
self.debug = args.debug
102+
self.pool_type = args.pool_type
101103

102104
if self.dataset_train_name.startswith("CIFAR"):
103105
self.input_size = 32
@@ -123,14 +125,16 @@ def __init__(self, args):
123125
unique_masks=self.unique_masks,
124126
level=self.level,
125127
filter_size=self.filter_size,
128+
first_filter_size=self.first_filter_size,
126129
act=self.act,
127130
scale_noise=self.scale_noise,
128131
noise_type=self.noise_type,
129132
use_act=self.use_act,
130133
dropout=self.dropout,
131-
first_filter_size=self.first_filter_size,
132134
train_masks=self.train_masks,
133-
debug=self.debug
135+
pool_type=self.pool_type,
136+
debug=self.debug,
137+
input_size=self.input_size
134138
)
135139

136140
self.loss_fn = nn.CrossEntropyLoss()
@@ -224,7 +228,9 @@ def test(self, dataloader):
224228

225229
return np.mean(losses), np.mean(accuracies)
226230

231+
print('\n\n****** Creating {} model ******\n\n'.format(args.net_type))
227232
setup = Model(args)
233+
print('\n\n****** Preparing {} dataset *******\n\n'.format(args.dataset_train))
228234
dataloader = Dataloader(args, setup.input_size)
229235
loader_train, loader_test = dataloader.create()
230236

@@ -255,10 +261,6 @@ def test(self, dataloader):
255261
init_epoch += 1
256262

257263

258-
print('\n\n****** Model Configuration ******\n\n')
259-
for arg in vars(args):
260-
print(arg, getattr(args, arg))
261-
262264
print('\n\n****** Model Graph ******\n\n')
263265
for arg in vars(model):
264266
print(arg, getattr(model, arg))
@@ -270,12 +272,19 @@ def test(self, dataloader):
270272

271273
print('\n\nModel: {}, {:.2f}M parameters\n\n'.format(args.net_type, sum(p.numel() for p in model.parameters()) / 1000000.))
272274

273-
if model.train_masks:
274-
msg = 'also training noise masks values'
275+
print('\n\n****** Model Configuration ******\n\n')
276+
for arg in vars(args):
277+
print(arg, getattr(args, arg))
278+
279+
if args.net_type != 'resnet18' and args.net_type != 'noiseresnet18' and (args.first_filter_size == 0 or args.filter_size == 0):
280+
if args.train_masks:
281+
msg = '(also training noise masks values)'
282+
else:
283+
msg = '(noise masks are fixed)'
275284
else:
276-
msg = 'moise masks are fixed'
285+
msg = ''
277286

278-
print('\n\nTraining Model {}\n\n'.format(msg))
287+
print('\n\nTraining {} model {}\n\n'.format(args.net_type, msg))
279288

280289
for epoch in range(init_epoch, args.nepochs, 1):
281290

0 commit comments

Comments
 (0)