2525except ImportError :
2626 raise ImportError ("Please install apex from https://www.github.com/nvidia/apex to run this example." )
2727
28+ def fast_collate (batch , memory_format ):
2829
29- def fast_collate (batch ):
3030 imgs = [img [0 ] for img in batch ]
3131 targets = torch .tensor ([target [1 ] for target in batch ], dtype = torch .int64 )
3232 w = imgs [0 ].size [0 ]
3333 h = imgs [0 ].size [1 ]
34- tensor = torch .zeros ( (len (imgs ), 3 , h , w ), dtype = torch .uint8 )
34+ tensor = torch .zeros ( (len (imgs ), 3 , h , w ), dtype = torch .uint8 ). contiguous ( memory_format = memory_format )
3535 for i , img in enumerate (imgs ):
3636 nump_array = np .asarray (img , dtype = np .uint8 )
3737 if (nump_array .ndim < 3 ):
3838 nump_array = np .expand_dims (nump_array , axis = - 1 )
3939 nump_array = np .rollaxis (nump_array , 2 )
40-
4140 tensor [i ] += torch .from_numpy (nump_array )
42-
4341 return tensor , targets
4442
4543
@@ -90,6 +88,7 @@ def parse():
9088 parser .add_argument ('--opt-level' , type = str )
9189 parser .add_argument ('--keep-batchnorm-fp32' , type = str , default = None )
9290 parser .add_argument ('--loss-scale' , type = str , default = None )
91+ parser .add_argument ('--channels-last' , type = bool , default = False )
9392 args = parser .parse_args ()
9493 return args
9594
@@ -127,6 +126,11 @@ def main():
127126
128127 assert torch .backends .cudnn .enabled , "Amp requires cudnn backend to be enabled."
129128
129+ if args .channels_last :
130+ memory_format = torch .channels_last
131+ else :
132+ memory_format = torch .contiguous_format
133+
130134 # create model
131135 if args .pretrained :
132136 print ("=> using pre-trained model '{}'" .format (args .arch ))
@@ -140,10 +144,10 @@ def main():
140144 print ("using apex synced BN" )
141145 model = apex .parallel .convert_syncbn_model (model )
142146
143- model = model .cuda ()
147+ model = model .cuda (). to ( memory_format = memory_format )
144148
145149 # Scale learning rate based on global batch size
146- args .lr = args .lr * float (args .batch_size * args .world_size )/ 256.
150+ args .lr = args .lr * float (args .batch_size * args .world_size )/ 256.
147151 optimizer = torch .optim .SGD (model .parameters (), args .lr ,
148152 momentum = args .momentum ,
149153 weight_decay = args .weight_decay )
@@ -161,7 +165,7 @@ def main():
161165 # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
162166 # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
163167 if args .distributed :
164- # By default, apex.parallel.DistributedDataParallel overlaps communication with
168+ # By default, apex.parallel.DistributedDataParallel overlaps communication with
165169 # computation in the backward pass.
166170 # model = DDP(model)
167171 # delay_allreduce delays all communication to the end of the backward pass.
@@ -218,16 +222,18 @@ def resume():
218222 train_sampler = torch .utils .data .distributed .DistributedSampler (train_dataset )
219223 val_sampler = torch .utils .data .distributed .DistributedSampler (val_dataset )
220224
225+ collate_fn = lambda b : fast_collate (b , memory_format )
226+
221227 train_loader = torch .utils .data .DataLoader (
222228 train_dataset , batch_size = args .batch_size , shuffle = (train_sampler is None ),
223- num_workers = args .workers , pin_memory = True , sampler = train_sampler , collate_fn = fast_collate )
229+ num_workers = args .workers , pin_memory = True , sampler = train_sampler , collate_fn = collate_fn )
224230
225231 val_loader = torch .utils .data .DataLoader (
226232 val_dataset ,
227233 batch_size = args .batch_size , shuffle = False ,
228234 num_workers = args .workers , pin_memory = True ,
229235 sampler = val_sampler ,
230- collate_fn = fast_collate )
236+ collate_fn = collate_fn )
231237
232238 if args .evaluate :
233239 validate (val_loader , model , criterion )
@@ -297,7 +303,7 @@ def preload(self):
297303 # else:
298304 self .next_input = self .next_input .float ()
299305 self .next_input = self .next_input .sub_ (self .mean ).div_ (self .std )
300-
306+
301307 def next (self ):
302308 torch .cuda .current_stream ().wait_stream (self .stream )
303309 input = self .next_input
@@ -361,20 +367,20 @@ def train(train_loader, model, criterion, optimizer, epoch):
361367
362368 # Measure accuracy
363369 prec1 , prec5 = accuracy (output .data , target , topk = (1 , 5 ))
364-
365- # Average loss and accuracy across processes for logging
370+
371+ # Average loss and accuracy across processes for logging
366372 if args .distributed :
367373 reduced_loss = reduce_tensor (loss .data )
368374 prec1 = reduce_tensor (prec1 )
369375 prec5 = reduce_tensor (prec5 )
370376 else :
371377 reduced_loss = loss .data
372-
378+
373379 # to_python_float incurs a host<->device sync
374380 losses .update (to_python_float (reduced_loss ), input .size (0 ))
375381 top1 .update (to_python_float (prec1 ), input .size (0 ))
376382 top5 .update (to_python_float (prec5 ), input .size (0 ))
377-
383+
378384 torch .cuda .synchronize ()
379385 batch_time .update ((time .time () - end )/ args .print_freq )
380386 end = time .time ()
0 commit comments