1717
1818import numpy as np
1919
20- try :
21- from apex .parallel import DistributedDataParallel as DDP
22- from apex .fp16_utils import *
23- from apex import amp , optimizers
24- from apex .multi_tensor_apply import multi_tensor_applier
25- except ImportError :
26- raise ImportError ("Please install apex from https://www.github.com/nvidia/apex to run this example." )
20+ from torch .nn .parallel import DistributedDataParallel as DDP
21+
22+ def to_python_float (scalar_tensor : torch .Tensor ):  # Convert a tensor to a plain Python float; .item() only accepts a tensor holding exactly one element.
23+ return scalar_tensor .float ().item ()  # cast to float32 first so reduced-precision (e.g. fp16 under autocast) values convert uniformly, then extract the Python number
2724
2825def fast_collate (batch , memory_format ):
2926
@@ -152,24 +149,9 @@ def main():
152149 momentum = args .momentum ,
153150 weight_decay = args .weight_decay )
154151
155- # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
156- # for convenient interoperation with argparse.
157- model , optimizer = amp .initialize (model , optimizer ,
158- opt_level = args .opt_level ,
159- keep_batchnorm_fp32 = args .keep_batchnorm_fp32 ,
160- loss_scale = args .loss_scale
161- )
162-
163- # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
164- # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
165- # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
166- # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
167152 if args .distributed :
168- # By default, apex.parallel.DistributedDataParallel overlaps communication with
169- # computation in the backward pass.
170- # model = DDP(model)
171- # delay_allreduce delays all communication to the end of the backward pass.
172- model = DDP (model , delay_allreduce = True )
153+ model = DDP (model )
154+ scaler = torch .amp .GradScaler ("cuda" )
173155
174156 # define loss function (criterion) and optimizer
175157 criterion = nn .CrossEntropyLoss ().cuda ()
@@ -245,7 +227,7 @@ def resume():
245227 train_sampler .set_epoch (epoch )
246228
247229 # train for one epoch
248- train (train_loader , model , criterion , optimizer , epoch )
230+ train (train_loader , model , criterion , optimizer , scaler , epoch )
249231
250232 # evaluate on validation set
251233 prec1 = validate (val_loader , model , criterion )
@@ -317,7 +299,7 @@ def next(self):
317299 return input , target
318300
319301
320- def train (train_loader , model , criterion , optimizer , epoch ):
302+ def train (train_loader , model , criterion , optimizer , scaler , epoch ):
321303 batch_time = AverageMeter ()
322304 losses = AverageMeter ()
323305 top1 = AverageMeter ()
@@ -341,24 +323,25 @@ def train(train_loader, model, criterion, optimizer, epoch):
341323 adjust_learning_rate (optimizer , epoch , i , len (train_loader ))
342324
343325 # compute output
344- if args .prof >= 0 : torch .cuda .nvtx .range_push ("forward" )
345- output = model (input )
346- if args .prof >= 0 : torch .cuda .nvtx .range_pop ()
347- loss = criterion (output , target )
326+ with torch .autocast (device_type = "cuda" ):
327+ if args .prof >= 0 : torch .cuda .nvtx .range_push ("forward" )
328+ output = model (input )
329+ if args .prof >= 0 : torch .cuda .nvtx .range_pop ()
330+ loss = criterion (output , target )
348331
349332 # compute gradient and do SGD step
350333 optimizer .zero_grad ()
351334
352335 if args .prof >= 0 : torch .cuda .nvtx .range_push ("backward" )
353- with amp .scale_loss (loss , optimizer ) as scaled_loss :
354- scaled_loss .backward ()
336+ scaler .scale (loss ).backward ()
355337 if args .prof >= 0 : torch .cuda .nvtx .range_pop ()
356338
357339 # for param in model.parameters():
358340 # print(param.data.double().sum().item(), param.grad.data.double().sum().item())
359341
360342 if args .prof >= 0 : torch .cuda .nvtx .range_push ("optimizer.step()" )
361- optimizer .step ()
343+ scaler .step (optimizer )
344+ scaler .update ()
362345 if args .prof >= 0 : torch .cuda .nvtx .range_pop ()
363346
364347 if i % args .print_freq == 0 :
0 commit comments