Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f29b3f8

Browse files
Make main_amp.py more profiling-friendly
1 parent 4b9858e commit f29b3f8

2 files changed

Lines changed: 29 additions & 14 deletions

File tree

examples/imagenet/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,9 @@ Running with the `--deterministic` flag should produce bitwise identical outputs
173173
regardless of what other options are used (see [Pytorch docs on reproducibility](https://pytorch.org/docs/stable/notes/randomness.html)).
174174
Since `--deterministic` disables `torch.backends.cudnn.benchmark`, `--deterministic` may
175175
cause a modest performance decrease.
176+
177+
## Profiling
178+
179+
If you're curious how the network actually looks on the CPU and GPU timelines (for example, how good is the overall utilization?
180+
Is the prefetcher really overlapping data transfers?) try profiling `main_amp.py`.
181+
[Detailed instructions can be found here](https://gist.github.com/mcarilli/213a4e698e4a0ae2234ddee56f4f3f95).

examples/imagenet/main_amp.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
6161
help='use pre-trained model')
6262

63-
parser.add_argument('--prof', dest='prof', action='store_true',
63+
parser.add_argument('--prof', default=-1, type=int,
6464
help='Only run 10 iterations for profiling.')
6565
parser.add_argument('--deterministic', action='store_true')
6666

@@ -236,8 +236,7 @@ def resume():
236236

237237
# train for one epoch
238238
train(train_loader, model, criterion, optimizer, epoch)
239-
if args.prof:
240-
break
239+
241240
# evaluate on validation set
242241
prec1 = validate(val_loader, model, criterion)
243242

@@ -323,33 +322,34 @@ def train(train_loader, model, criterion, optimizer, epoch):
323322
i = 0
324323
while input is not None:
325324
i += 1
325+
if args.prof >= 0 and i == args.prof:
326+
print("Profiling begun at iteration {}".format(i))
327+
torch.cuda.cudart().cudaProfilerStart()
326328

327-
adjust_learning_rate(optimizer, epoch, i, len(train_loader))
329+
if args.prof >= 0: torch.cuda.nvtx.range_push("Body of iteration {}".format(i))
328330

329-
if args.prof:
330-
if i > 10:
331-
break
331+
adjust_learning_rate(optimizer, epoch, i, len(train_loader))
332332

333333
# compute output
334-
if args.prof: torch.cuda.nvtx.range_push("forward")
334+
if args.prof >= 0: torch.cuda.nvtx.range_push("forward")
335335
output = model(input)
336-
if args.prof: torch.cuda.nvtx.range_pop()
336+
if args.prof >= 0: torch.cuda.nvtx.range_pop()
337337
loss = criterion(output, target)
338338

339339
# compute gradient and do SGD step
340340
optimizer.zero_grad()
341341

342-
if args.prof: torch.cuda.nvtx.range_push("backward")
342+
if args.prof >= 0: torch.cuda.nvtx.range_push("backward")
343343
with amp.scale_loss(loss, optimizer) as scaled_loss:
344344
scaled_loss.backward()
345-
if args.prof: torch.cuda.nvtx.range_pop()
345+
if args.prof >= 0: torch.cuda.nvtx.range_pop()
346346

347347
# for param in model.parameters():
348348
# print(param.data.double().sum().item(), param.grad.data.double().sum().item())
349349

350-
if args.prof: torch.cuda.nvtx.range_push("step")
350+
if args.prof >= 0: torch.cuda.nvtx.range_push("optimizer.step()")
351351
optimizer.step()
352-
if args.prof: torch.cuda.nvtx.range_pop()
352+
if args.prof >= 0: torch.cuda.nvtx.range_pop()
353353

354354
if i%args.print_freq == 0:
355355
# Every print_freq iterations, check the loss, accuracy, and speed.
@@ -388,8 +388,17 @@ def train(train_loader, model, criterion, optimizer, epoch):
388388
args.world_size*args.batch_size/batch_time.avg,
389389
batch_time=batch_time,
390390
loss=losses, top1=top1, top5=top5))
391-
391+
if args.prof >= 0: torch.cuda.nvtx.range_push("prefetcher.next()")
392392
input, target = prefetcher.next()
393+
if args.prof >= 0: torch.cuda.nvtx.range_pop()
394+
395+
# Pop range "Body of iteration {}".format(i)
396+
if args.prof >= 0: torch.cuda.nvtx.range_pop()
397+
398+
if args.prof >= 0 and i == args.prof + 10:
399+
print("Profiling ended at iteration {}".format(i))
400+
torch.cuda.cudart().cudaProfilerStop()
401+
quit()
393402

394403

395404
def validate(val_loader, model, criterion):

0 commit comments

Comments
 (0)