Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a730f38

Browse files
Adding simple distributed example for #200
1 parent 7f0d8c8 commit a730f38

4 files changed

Lines changed: 82 additions & 2 deletions

File tree

examples/imagenet/main_amp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def main():
117117
args.world_size = 1
118118

119119
if args.distributed:
120-
args.gpu = args.local_rank % torch.cuda.device_count()
120+
args.gpu = args.local_rank
121121
torch.cuda.set_device(args.gpu)
122122
torch.distributed.init_process_group(backend='nccl',
123123
init_method='env://')
@@ -334,7 +334,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
334334
if args.prof: torch.cuda.nvtx.range_pop()
335335

336336
if i%args.print_freq == 0:
337-
# Every print_freq iterations, check the loss accuracy and speed.
337+
# Every print_freq iterations, check the loss, accuracy, and speed.
338338
# For best performance, it doesn't make sense to print these metrics every
339339
# iteration, since they incur an allreduce and some host<->device syncs.
340340

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
**distributed_data_parallel.py** and **run.sh** show an example using Amp with
2+
[apex.parallel.DistributedDataParallel](https://nvidia.github.io/apex/parallel.html) or
3+
[torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#distributeddataparallel)
4+
and the PyTorch multiprocess launcher script,
5+
[torch.distributed.launch](https://pytorch.org/docs/master/distributed.html#launch-utility).
6+
The use of `Amp` with distributed training does not need to change from ordinary
7+
single-process use. The only gotcha is that wrapping your model with `DistributedDataParallel` must
8+
come after the call to `amp.initialize`. Test via
9+
```bash
10+
bash run.sh
11+
```
12+
13+
**This is intended purely as an instructional example, not a performance showcase.**
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import torch
2+
import argparse
3+
import os
4+
from apex import amp
5+
# FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead)
6+
from apex.parallel import DistributedDataParallel
7+
8+
# Minimal instructional example: train a single Linear layer on a fixed fake
# batch with Amp (opt_level O1), optionally wrapped in DistributedDataParallel
# when started by a multiprocess launcher.

parser = argparse.ArgumentParser()
# FOR DISTRIBUTED: Parse the local rank. torch.distributed.launch supplies it
# automatically as a command-line argument; older releases spell it
# --local_rank and newer ones --local-rank, so both spellings are accepted.
# Env-based launchers export LOCAL_RANK instead of passing an argument, so the
# environment value (or 0 for plain single-process runs) is used as the default.
parser.add_argument("--local_rank", "--local-rank",
                    default=int(os.environ.get("LOCAL_RANK", 0)), type=int)
args = parser.parse_args()

# FOR DISTRIBUTED: If we are running under torch.distributed.launch,
# the 'WORLD_SIZE' environment variable will also be set automatically.
# A world size of 1 (or no variable at all) means ordinary single-process use.
args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

if args.distributed:
    # FOR DISTRIBUTED: Set the device according to local_rank.
    torch.cuda.set_device(args.local_rank)

    # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide
    # environment variables, and requires that you use init_method=`env://`.
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')

torch.backends.cudnn.benchmark = True

# Batch size, input feature count, and output feature count of the toy problem.
N, D_in, D_out = 64, 1024, 16

# Each process receives its own batch of "fake input data" and "fake target data."
# The "training loop" in each process just uses this fake batch over and over.
# https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic
# example of distributed data sampling for both training and validation.
x = torch.randn(N, D_in, device='cuda')
y = torch.randn(N, D_out, device='cuda')

model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# amp.initialize must see the bare model; the distributed wrapper is applied
# only afterwards (see below).
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

if args.distributed:
    # FOR DISTRIBUTED: After amp.initialize, wrap the model with
    # apex.parallel.DistributedDataParallel.
    model = DistributedDataParallel(model)
    # torch.nn.parallel.DistributedDataParallel is also fine, with some added args:
    # model = torch.nn.parallel.DistributedDataParallel(model,
    #                                                   device_ids=[args.local_rank],
    #                                                   output_device=args.local_rank)

loss_fn = torch.nn.MSELoss()

for t in range(500):
    optimizer.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    # Backpropagate through Amp's scaled loss rather than the raw loss.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()

# Only the process with local rank 0 reports, so the result is printed once per node.
if args.local_rank == 0:
    print("final loss = ", loss)

examples/simple/distributed/run.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#!/bin/bash
2+
# Start the example under the PyTorch launcher with two processes on this node.
python -m torch.distributed.launch --nproc_per_node=2 distributed_data_parallel.py

0 commit comments

Comments
 (0)