Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 2ca894d

Browse files
VitalyFedyuninmcarilli
authored andcommitted
Channels last support (#668)
1 parent b66ffc1 commit 2ca894d

1 file changed

Lines changed: 20 additions & 14 deletions

File tree

examples/imagenet/main_amp.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,19 @@
2525
except ImportError:
2626
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
2727

28+
def fast_collate(batch, memory_format):
2829

29-
def fast_collate(batch):
3030
imgs = [img[0] for img in batch]
3131
targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
3232
w = imgs[0].size[0]
3333
h = imgs[0].size[1]
34-
tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
34+
tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8).contiguous(memory_format=memory_format)
3535
for i, img in enumerate(imgs):
3636
nump_array = np.asarray(img, dtype=np.uint8)
3737
if(nump_array.ndim < 3):
3838
nump_array = np.expand_dims(nump_array, axis=-1)
3939
nump_array = np.rollaxis(nump_array, 2)
40-
4140
tensor[i] += torch.from_numpy(nump_array)
42-
4341
return tensor, targets
4442

4543

@@ -90,6 +88,7 @@ def parse():
9088
parser.add_argument('--opt-level', type=str)
9189
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
9290
parser.add_argument('--loss-scale', type=str, default=None)
91+
parser.add_argument('--channels-last', type=bool, default=False)
9392
args = parser.parse_args()
9493
return args
9594

@@ -127,6 +126,11 @@ def main():
127126

128127
assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
129128

129+
if args.channels_last:
130+
memory_format = torch.channels_last
131+
else:
132+
memory_format = torch.contiguous_format
133+
130134
# create model
131135
if args.pretrained:
132136
print("=> using pre-trained model '{}'".format(args.arch))
@@ -140,10 +144,10 @@ def main():
140144
print("using apex synced BN")
141145
model = apex.parallel.convert_syncbn_model(model)
142146

143-
model = model.cuda()
147+
model = model.cuda().to(memory_format=memory_format)
144148

145149
# Scale learning rate based on global batch size
146-
args.lr = args.lr*float(args.batch_size*args.world_size)/256.
150+
args.lr = args.lr*float(args.batch_size*args.world_size)/256.
147151
optimizer = torch.optim.SGD(model.parameters(), args.lr,
148152
momentum=args.momentum,
149153
weight_decay=args.weight_decay)
@@ -161,7 +165,7 @@ def main():
161165
# before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
162166
# the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
163167
if args.distributed:
164-
# By default, apex.parallel.DistributedDataParallel overlaps communication with
168+
# By default, apex.parallel.DistributedDataParallel overlaps communication with
165169
# computation in the backward pass.
166170
# model = DDP(model)
167171
# delay_allreduce delays all communication to the end of the backward pass.
@@ -218,16 +222,18 @@ def resume():
218222
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
219223
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
220224

225+
collate_fn = lambda b: fast_collate(b, memory_format)
226+
221227
train_loader = torch.utils.data.DataLoader(
222228
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
223-
num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
229+
num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=collate_fn)
224230

225231
val_loader = torch.utils.data.DataLoader(
226232
val_dataset,
227233
batch_size=args.batch_size, shuffle=False,
228234
num_workers=args.workers, pin_memory=True,
229235
sampler=val_sampler,
230-
collate_fn=fast_collate)
236+
collate_fn=collate_fn)
231237

232238
if args.evaluate:
233239
validate(val_loader, model, criterion)
@@ -297,7 +303,7 @@ def preload(self):
297303
# else:
298304
self.next_input = self.next_input.float()
299305
self.next_input = self.next_input.sub_(self.mean).div_(self.std)
300-
306+
301307
def next(self):
302308
torch.cuda.current_stream().wait_stream(self.stream)
303309
input = self.next_input
@@ -361,20 +367,20 @@ def train(train_loader, model, criterion, optimizer, epoch):
361367

362368
# Measure accuracy
363369
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
364-
365-
# Average loss and accuracy across processes for logging
370+
371+
# Average loss and accuracy across processes for logging
366372
if args.distributed:
367373
reduced_loss = reduce_tensor(loss.data)
368374
prec1 = reduce_tensor(prec1)
369375
prec5 = reduce_tensor(prec5)
370376
else:
371377
reduced_loss = loss.data
372-
378+
373379
# to_python_float incurs a host<->device sync
374380
losses.update(to_python_float(reduced_loss), input.size(0))
375381
top1.update(to_python_float(prec1), input.size(0))
376382
top5.update(to_python_float(prec5), input.size(0))
377-
383+
378384
torch.cuda.synchronize()
379385
batch_time.update((time.time() - end)/args.print_freq)
380386
end = time.time()

0 commit comments

Comments
 (0)