2525except ImportError :
2626 raise ImportError ("Please install apex from https://www.github.com/nvidia/apex to run this example." )
2727
28- model_names = sorted (name for name in models .__dict__
29- if name .islower () and not name .startswith ("__" )
30- and callable (models .__dict__ [name ]))
31-
32- parser = argparse .ArgumentParser (description = 'PyTorch ImageNet Training' )
33- parser .add_argument ('data' , metavar = 'DIR' ,
34- help = 'path to dataset' )
35- parser .add_argument ('--arch' , '-a' , metavar = 'ARCH' , default = 'resnet18' ,
36- choices = model_names ,
37- help = 'model architecture: ' +
38- ' | ' .join (model_names ) +
39- ' (default: resnet18)' )
40- parser .add_argument ('-j' , '--workers' , default = 4 , type = int , metavar = 'N' ,
41- help = 'number of data loading workers (default: 4)' )
42- parser .add_argument ('--epochs' , default = 90 , type = int , metavar = 'N' ,
43- help = 'number of total epochs to run' )
44- parser .add_argument ('--start-epoch' , default = 0 , type = int , metavar = 'N' ,
45- help = 'manual epoch number (useful on restarts)' )
46- parser .add_argument ('-b' , '--batch-size' , default = 256 , type = int ,
47- metavar = 'N' , help = 'mini-batch size per process (default: 256)' )
48- parser .add_argument ('--lr' , '--learning-rate' , default = 0.1 , type = float ,
49- metavar = 'LR' , help = 'Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.' )
50- parser .add_argument ('--momentum' , default = 0.9 , type = float , metavar = 'M' ,
51- help = 'momentum' )
52- parser .add_argument ('--weight-decay' , '--wd' , default = 1e-4 , type = float ,
53- metavar = 'W' , help = 'weight decay (default: 1e-4)' )
54- parser .add_argument ('--print-freq' , '-p' , default = 10 , type = int ,
55- metavar = 'N' , help = 'print frequency (default: 10)' )
56- parser .add_argument ('--resume' , default = '' , type = str , metavar = 'PATH' ,
57- help = 'path to latest checkpoint (default: none)' )
58- parser .add_argument ('-e' , '--evaluate' , dest = 'evaluate' , action = 'store_true' ,
59- help = 'evaluate model on validation set' )
60- parser .add_argument ('--pretrained' , dest = 'pretrained' , action = 'store_true' ,
61- help = 'use pre-trained model' )
62-
63- parser .add_argument ('--prof' , default = - 1 , type = int ,
64- help = 'Only run 10 iterations for profiling.' )
65- parser .add_argument ('--deterministic' , action = 'store_true' )
66-
67- parser .add_argument ("--local_rank" , default = 0 , type = int )
68- parser .add_argument ('--sync_bn' , action = 'store_true' ,
69- help = 'enabling apex sync BN.' )
70-
71- parser .add_argument ('--opt-level' , type = str )
72- parser .add_argument ('--keep-batchnorm-fp32' , type = str , default = None )
73- parser .add_argument ('--loss-scale' , type = str , default = None )
74-
75- cudnn .benchmark = True
7628
7729def fast_collate (batch ):
7830 imgs = [img [0 ] for img in batch ]
@@ -90,24 +42,75 @@ def fast_collate(batch):
9042
9143 return tensor , targets
9244
93- best_prec1 = 0
94- args = parser .parse_args ()
9545
96- print ("opt_level = {}" .format (args .opt_level ))
97- print ("keep_batchnorm_fp32 = {}" .format (args .keep_batchnorm_fp32 ), type (args .keep_batchnorm_fp32 ))
98- print ("loss_scale = {}" .format (args .loss_scale ), type (args .loss_scale ))
99-
100- print ("\n CUDNN VERSION: {}\n " .format (torch .backends .cudnn .version ()))
def parse(argv=None):
    """Build the command-line parser for the ImageNet training script and parse it.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so a
            plain ``parse()`` call behaves exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    # Architectures offered to --arch: every public, lowercase, callable
    # attribute exposed by the ``models`` module imported at file level.
    model_names = sorted(
        name for name, obj in vars(models).items()
        if name.islower() and not name.startswith("__") and callable(obj)
    )

    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('data', metavar='DIR',
                        help='path to dataset')
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                        choices=model_names,
                        help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        metavar='N', help='mini-batch size per process (default: 256)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--print-freq', '-p', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')

    parser.add_argument('--prof', default=-1, type=int,
                        help='Only run 10 iterations for profiling.')
    parser.add_argument('--deterministic', action='store_true')

    # NOTE(review): --local_rank is presumably injected by the distributed
    # launcher rather than typed by hand -- confirm against the launch command.
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument('--sync_bn', action='store_true',
                        help='enabling apex sync BN.')

    # Mixed-precision options are kept as raw strings (or None when omitted);
    # any interpretation of them happens in the caller.
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)

    # Forwarding argv lets callers (tests, notebooks) supply arguments
    # explicitly instead of relying on the process command line.
    return parser.parse_args(argv)
10795
10896def main ():
10997 global best_prec1 , args
11098
99+ args = parse ()
100+ print ("opt_level = {}" .format (args .opt_level ))
101+ print ("keep_batchnorm_fp32 = {}" .format (args .keep_batchnorm_fp32 ), type (args .keep_batchnorm_fp32 ))
102+ print ("loss_scale = {}" .format (args .loss_scale ), type (args .loss_scale ))
103+
104+ print ("\n CUDNN VERSION: {}\n " .format (torch .backends .cudnn .version ()))
105+
106+ cudnn .benchmark = True
107+ best_prec1 = 0
108+ if args .deterministic :
109+ cudnn .benchmark = False
110+ cudnn .deterministic = True
111+ torch .manual_seed (args .local_rank )
112+ torch .set_printoptions (precision = 10 )
113+
111114 args .distributed = False
112115 if 'WORLD_SIZE' in os .environ :
113116 args .distributed = int (os .environ ['WORLD_SIZE' ]) > 1
0 commit comments