Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e86f986

Browse files
committed
Put parser in a function to make script importable
1 parent 8d0deb0 commit e86f986

1 file changed

Lines changed: 63 additions & 60 deletions

File tree

examples/imagenet/main_amp.py

Lines changed: 63 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -25,54 +25,6 @@
2525
except ImportError:
2626
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
2727

28-
model_names = sorted(name for name in models.__dict__
29-
if name.islower() and not name.startswith("__")
30-
and callable(models.__dict__[name]))
31-
32-
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
33-
parser.add_argument('data', metavar='DIR',
34-
help='path to dataset')
35-
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
36-
choices=model_names,
37-
help='model architecture: ' +
38-
' | '.join(model_names) +
39-
' (default: resnet18)')
40-
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
41-
help='number of data loading workers (default: 4)')
42-
parser.add_argument('--epochs', default=90, type=int, metavar='N',
43-
help='number of total epochs to run')
44-
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
45-
help='manual epoch number (useful on restarts)')
46-
parser.add_argument('-b', '--batch-size', default=256, type=int,
47-
metavar='N', help='mini-batch size per process (default: 256)')
48-
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
49-
metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
50-
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
51-
help='momentum')
52-
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
53-
metavar='W', help='weight decay (default: 1e-4)')
54-
parser.add_argument('--print-freq', '-p', default=10, type=int,
55-
metavar='N', help='print frequency (default: 10)')
56-
parser.add_argument('--resume', default='', type=str, metavar='PATH',
57-
help='path to latest checkpoint (default: none)')
58-
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
59-
help='evaluate model on validation set')
60-
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
61-
help='use pre-trained model')
62-
63-
parser.add_argument('--prof', default=-1, type=int,
64-
help='Only run 10 iterations for profiling.')
65-
parser.add_argument('--deterministic', action='store_true')
66-
67-
parser.add_argument("--local_rank", default=0, type=int)
68-
parser.add_argument('--sync_bn', action='store_true',
69-
help='enabling apex sync BN.')
70-
71-
parser.add_argument('--opt-level', type=str)
72-
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
73-
parser.add_argument('--loss-scale', type=str, default=None)
74-
75-
cudnn.benchmark = True
7628

7729
def fast_collate(batch):
7830
imgs = [img[0] for img in batch]
@@ -90,24 +42,75 @@ def fast_collate(batch):
9042

9143
return tensor, targets
9244

93-
best_prec1 = 0
94-
args = parser.parse_args()
9545

96-
print("opt_level = {}".format(args.opt_level))
97-
print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
98-
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
99-
100-
print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
46+
def parse():
47+
model_names = sorted(name for name in models.__dict__
48+
if name.islower() and not name.startswith("__")
49+
and callable(models.__dict__[name]))
10150

102-
if args.deterministic:
103-
cudnn.benchmark = False
104-
cudnn.deterministic = True
105-
torch.manual_seed(args.local_rank)
106-
torch.set_printoptions(precision=10)
51+
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
52+
parser.add_argument('data', metavar='DIR',
53+
help='path to dataset')
54+
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
55+
choices=model_names,
56+
help='model architecture: ' +
57+
' | '.join(model_names) +
58+
' (default: resnet18)')
59+
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
60+
help='number of data loading workers (default: 4)')
61+
parser.add_argument('--epochs', default=90, type=int, metavar='N',
62+
help='number of total epochs to run')
63+
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
64+
help='manual epoch number (useful on restarts)')
65+
parser.add_argument('-b', '--batch-size', default=256, type=int,
66+
metavar='N', help='mini-batch size per process (default: 256)')
67+
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
68+
metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
69+
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
70+
help='momentum')
71+
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
72+
metavar='W', help='weight decay (default: 1e-4)')
73+
parser.add_argument('--print-freq', '-p', default=10, type=int,
74+
metavar='N', help='print frequency (default: 10)')
75+
parser.add_argument('--resume', default='', type=str, metavar='PATH',
76+
help='path to latest checkpoint (default: none)')
77+
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
78+
help='evaluate model on validation set')
79+
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
80+
help='use pre-trained model')
81+
82+
parser.add_argument('--prof', default=-1, type=int,
83+
help='Only run 10 iterations for profiling.')
84+
parser.add_argument('--deterministic', action='store_true')
85+
86+
parser.add_argument("--local_rank", default=0, type=int)
87+
parser.add_argument('--sync_bn', action='store_true',
88+
help='enabling apex sync BN.')
89+
90+
parser.add_argument('--opt-level', type=str)
91+
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
92+
parser.add_argument('--loss-scale', type=str, default=None)
93+
args = parser.parse_args()
94+
return args
10795

10896
def main():
10997
global best_prec1, args
11098

99+
args = parse()
100+
print("opt_level = {}".format(args.opt_level))
101+
print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
102+
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
103+
104+
print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
105+
106+
cudnn.benchmark = True
107+
best_prec1 = 0
108+
if args.deterministic:
109+
cudnn.benchmark = False
110+
cudnn.deterministic = True
111+
torch.manual_seed(args.local_rank)
112+
torch.set_printoptions(precision=10)
113+
111114
args.distributed = False
112115
if 'WORLD_SIZE' in os.environ:
113116
args.distributed = int(os.environ['WORLD_SIZE']) > 1

0 commit comments

Comments
 (0)