@@ -2,46 +2,54 @@ Getting Started with Distributed Data Parallel
22=================================================
33**Author**: `Shen Li <https://mrshenli.github.io/>`_
44
5- `DistributedDataParallel <https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html>`__
6- (DDP) implements data parallelism at the module level. It uses communication
7- collectives in the `torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
8- package to synchronize gradients, parameters, and buffers. Parallelism is
9- available both within a process and across processes. Within a process, DDP
10- replicates the input module to devices specified in ``device_ids``, scatters
11- inputs along the batch dimension accordingly, and gathers outputs to the
12- ``output_device``, which is similar to
13- `DataParallel <https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html>`__.
14- Across processes, DDP inserts necessary parameter synchronizations in forward
15- passes and gradient synchronizations in backward passes. It is up to users to
16- map processes to available resources, as long as processes do not share GPU
17- devices. The recommended (usually fastest) approach is to create a process for
18- every module replica, i.e., no module replication within a process. The code in
19- this tutorial runs on an 8-GPU server, but it can be easily generalized to
20- other environments.
5+ `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__
6+ (DDP) implements data parallelism at the module level and can run across
7+ multiple machines. Applications using DDP should spawn multiple processes and
8+ create a single DDP instance per process. DDP uses collective communications in the
9+ `torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
10+ package to synchronize gradients and buffers. More specifically, DDP registers
11+ an autograd hook for each parameter given by ``model.parameters()``, and the
12+ hook fires when the corresponding gradient is computed in the backward
13+ pass. DDP then uses that signal to trigger gradient synchronization across
14+ processes. Please refer to the
15+ `DDP design note <https://pytorch.org/docs/master/notes/ddp.html>`__ for more details.
16+
17+
18+ The recommended way to use DDP is to spawn one process for each model replica,
19+ where a model replica can span multiple devices. DDP processes can be
20+ placed on the same machine or across machines, but GPU devices cannot be
21+ shared across processes. This tutorial starts from a basic DDP use case and
22+ then demonstrates more advanced use cases, including checkpointing models and
23+ combining DDP with model parallel.
24+
25+
26+ .. note::
27+   The code in this tutorial runs on an 8-GPU server, but it can be easily
28+   generalized to other environments.
29+
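
The hook-driven flow described above can be pictured with a toy, pure-Python model. This is only a sketch of the idea, not DDP's implementation: ``ToyParam`` and ``ToyReducer`` are hypothetical stand-ins invented for illustration, and real DDP additionally groups gradients into buckets, as the design note explains.

```python
class ToyParam:
    """Hypothetical stand-in for a parameter that notifies hooks when its gradient is set."""

    def __init__(self, name):
        self.name = name
        self.grad = None
        self._hooks = []

    def register_hook(self, fn):
        self._hooks.append(fn)

    def set_grad(self, value):
        # Called by the toy "backward pass" once this gradient has been computed.
        self.grad = value
        for fn in self._hooks:
            fn(self)


class ToyReducer:
    """Hypothetical stand-in for the component that reacts to gradient hooks."""

    def __init__(self, params):
        self.pending = {p.name for p in params}
        self.events = []
        for p in params:
            p.register_hook(self._on_grad_ready)

    def _on_grad_ready(self, param):
        self.events.append(f"ready: {param.name}")
        self.pending.discard(param.name)
        if not self.pending:
            # In this toy model, synchronization starts once every hook has fired.
            self.events.append("synchronize")


params = [ToyParam("layer1.weight"), ToyParam("layer2.weight")]
reducer = ToyReducer(params)

# A backward pass produces gradients in roughly the reverse order of the forward pass.
for p in reversed(params):
    p.set_grad(1.0)

print(reducer.events)
```

The point of the sketch is the control flow: nothing polls for gradients; each hook firing is the signal that drives synchronization.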
2130
2231Comparison between ``DataParallel`` and ``DistributedDataParallel``
2332-------------------------------------------------------------------
2433
2534Before we dive in, let's clarify why, despite the added complexity, you would
2635consider using ``DistributedDataParallel`` over ``DataParallel``:
2736
28- - First, recall from the
37+ - First, ``DataParallel`` is single-process, multi-thread, and only works on a
38+   single machine, while ``DistributedDataParallel`` is multi-process and works
39+   for both single- and multi-machine training. ``DataParallel`` is usually
40+   slower than ``DistributedDataParallel`` even on a single machine due to GIL
41+   contention across threads, per-iteration model replication, and additional
42+   overhead introduced by scattering inputs and gathering outputs.
43+ - Recall from the
2944   `prior tutorial <https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html>`__
3045   that if your model is too large to fit on a single GPU, you must use **model parallel**
3146   to split it across multiple GPUs. ``DistributedDataParallel`` works with
32-   **model parallel**; ``DataParallel`` does not at this time.
33- - ``DataParallel`` is single-process, multi-thread, and only works on a single
34-   machine, while ``DistributedDataParallel`` is multi-process and works for both
35-   single- and multi- machine training. Thus, even for single machine training,
36-   where your **data** is small enough to fit on a single machine, ``DistributedDataParallel``
37-   is expected to be faster than ``DataParallel``. ``DistributedDataParallel``
38-   also replicates models upfront instead of on each iteration and gets Global
39-   Interpreter Lock out of the way.
40- - If both your data is too large to fit on one machine **and** your
41-   model is too large to fit on a single GPU, you can combine model parallel
42-   (splitting a single model across multiple GPUs) with ``DistributedDataParallel``.
43-   Under this regime, each ``DistributedDataParallel`` process could use model parallel,
44-   and all processes collectively would use data parallel.
47+   **model parallel**; ``DataParallel`` does not at this time. When DDP is combined
48+   with model parallel, each DDP process uses model parallel, and all processes
49+   collectively use data parallel.
50+ - If your model needs to span multiple machines or if your use case does not fit
51+   into the data parallelism paradigm, please see `the RPC API <https://pytorch.org/docs/stable/rpc.html>`__
52+   for more generic distributed training support.
4553
4654Basic Use Case
4755--------------
@@ -70,18 +78,14 @@ be found in
7078         # initialize the process group
7179         dist.init_process_group("gloo", rank=rank, world_size=world_size)
7280
73-         # Explicitly setting seed to make sure that models created in two processes
74-         # start from same random weights and biases.
75-         torch.manual_seed(42)
76-
7781
7882     def cleanup():
7983         dist.destroy_process_group()
8084
8185 Now, let's create a toy module, wrap it with DDP, and feed it with some dummy
82- input data. Please note, if training starts from random parameters, you might
83- want to make sure that all DDP processes use the same initial values.
84- Otherwise, global gradient synchronizes will not make sense .
86+ input data. Note that because DDP broadcasts model states from the rank 0 process to
87+ all other processes in the DDP constructor, you don't need to worry about
88+ different DDP processes starting from different initial parameter values.
8589
8690.. code:: python
8791
@@ -97,24 +101,19 @@ Otherwise, global gradient synchronizes will not make sense.
97101
98102
99103     def demo_basic(rank, world_size):
104+         print(f"Running basic DDP example on rank {rank}.")
100105         setup(rank, world_size)
101106
102-         # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
103-         # rank 2 uses GPUs [4, 5, 6, 7].
104-         n = torch.cuda.device_count() // world_size
105-         device_ids = list(range(rank * n, (rank + 1) * n))
106-
107-         # create model and move it to device_ids[0]
108-         model = ToyModel().to(device_ids[0])
109-         # output_device defaults to device_ids[0]
110-         ddp_model = DDP(model, device_ids=device_ids)
107+         # create model and move it to GPU with id rank
108+         model = ToyModel().to(rank)
109+         ddp_model = DDP(model, device_ids=[rank])
111110
112111         loss_fn = nn.MSELoss()
113112         optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
114113
115114         optimizer.zero_grad()
116115         outputs = ddp_model(torch.randn(20, 10))
117-         labels = torch.randn(20, 5).to(device_ids[0])
116+         labels = torch.randn(20, 5).to(rank)
118117         loss_fn(outputs, labels).backward()
119118         optimizer.step()
120119
@@ -127,23 +126,27 @@ Otherwise, global gradient synchronizes will not make sense.
127126                  nprocs=world_size,
128127                  join=True)
129128
130- As you can see, DDP wraps lower level distributed communication details, and
131- provides a clean API as if it is a local model. For basic use cases, DDP only
129+ As you can see, DDP wraps lower-level distributed communication details and
130+ provides a clean API as if it were a local model. Gradient synchronization
131+ communications take place during the backward pass and overlap with the
132+ backward computation. When ``backward()`` returns, ``param.grad`` already
133+ contains the synchronized gradient tensor. For basic use cases, DDP only
132134requires a few more LoCs to set up the process group. When applying DDP to more
133- advanced use cases, there are some caveats that require cautions .
135+ advanced use cases, some caveats require caution .
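
As a rough mental model of what "synchronized gradient" means here, the following pure-Python sketch averages per-rank gradients the way an all-reduce followed by a division by the world size would. It assumes DDP's gradient averaging across processes; the helper name ``all_reduce_mean`` and the numbers are made up for illustration, and no ``torch.distributed`` call is involved.

```python
def all_reduce_mean(per_rank_grads):
    """Average gradients element-wise across ranks; every rank gets the same result."""
    world_size = len(per_rank_grads)
    summed = [sum(values) for values in zip(*per_rank_grads)]
    return [s / world_size for s in summed]


# Hypothetical local gradients, computed by two ranks on different input batches.
local_grads = [
    [0.5, -1.0, 2.0],  # rank 0
    [1.5, 1.0, 0.0],   # rank 1
]
synced = all_reduce_mean(local_grads)

# Every rank starts from the same parameters and applies the same SGD step
# using the synchronized gradient, so the replicas stay identical.
lr = 0.5
replicas = [[1.0, 1.0, 1.0] for _ in local_grads]
for params in replicas:
    for i, g in enumerate(synced):
        params[i] -= lr * g

print(synced)    # [1.0, 0.0, 1.0]
print(replicas)  # [[0.5, 1.0, 0.5], [0.5, 1.0, 0.5]]
```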
134136
135137Skewed Processing Speeds
136138------------------------
137139
138- In DDP, constructor, forward method, and differentiation of the outputs are
139- distributed synchronization points. Different processes are expected to reach
140- synchronization points in the same order and enter each synchronization point
141- at roughly the same time. Otherwise, fast processes might arrive early and
142- timeout on waiting for stragglers. Hence, users are responsible for balancing
143- workloads distributions across processes. Sometimes, skewed processing speeds
144- are inevitable due to, e.g., network delays, resource contentions,
145- unpredictable workload spikes. To avoid timeouts in these situations, make
146- sure that you pass a sufficiently large ``timeout`` value when calling
140+ In DDP, the constructor, the forward pass, and the backward pass are
141+ distributed synchronization points. Different processes are expected to launch
142+ the same number of synchronizations, reach these synchronization points in
143+ the same order, and enter each synchronization point at roughly the same time.
144+ Otherwise, fast processes might arrive early and time out while waiting for
145+ stragglers. Hence, users are responsible for balancing workload distribution
146+ across processes. Sometimes, skewed processing speeds are inevitable due to,
147+ e.g., network delays, resource contention, or unpredictable workload spikes. To
148+ avoid timeouts in these situations, make sure that you pass a sufficiently
149+ large ``timeout`` value when calling
147150`init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__.
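
The straggler problem can be reproduced in miniature with standard-library threads. This is an analogy only: ``threading.Barrier`` stands in for a distributed synchronization point, the hypothetical ``run_workers`` helper stands in for a set of ranks, and the delays are arbitrary. In DDP itself, the corresponding knob is the ``timeout`` argument of ``init_process_group``.

```python
import threading
import time


def run_workers(delays, timeout):
    """Run one thread per simulated rank; each sleeps, then waits at a shared barrier."""
    barrier = threading.Barrier(len(delays))
    results = [None] * len(delays)

    def worker(rank, delay):
        time.sleep(delay)  # simulated compute time before the synchronization point
        try:
            barrier.wait(timeout=timeout)
            results[rank] = "ok"
        except threading.BrokenBarrierError:
            # Raised in the waiter that timed out and in any peer that arrives afterwards.
            results[rank] = "failed"

    threads = [
        threading.Thread(target=worker, args=(rank, delay))
        for rank, delay in enumerate(delays)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


# Rank 1 is a straggler: it reaches the barrier 0.5 s after rank 0.
too_short = run_workers([0.0, 0.5], timeout=0.1)
generous = run_workers([0.0, 0.5], timeout=5.0)
print(too_short)  # ['failed', 'failed']
print(generous)   # ['ok', 'ok']
```

With the short timeout, the fast worker gives up before the straggler arrives and the whole group fails; with a generous timeout, the same skew is harmless.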
148151
149152Save and Load Checkpoints
@@ -156,27 +159,23 @@ for more details. When using DDP, one optimization is to save the model in
156159only one process and then load it to all processes, reducing write overhead.
157160This is correct because all processes start from the same parameters and
158161gradients are synchronized in backward passes, and hence optimizers should keep
159- setting parameters to same values. If you use this optimization, make sure all
162+ setting parameters to the same values. If you use this optimization, make sure no
160163process starts loading before the saving is finished. In addition, when
161164loading the module, you need to provide an appropriate ``map_location``
162165argument to prevent a process from stepping into other processes' devices. If ``map_location``
163166is missing, ``torch.load`` will first load the module to CPU and then copy each
164167parameter to where it was saved, which would result in all processes on the
165- same machine using the same set of devices.
168+ same machine using the same set of devices. For more advanced failure recovery
169+ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elastic>`__.
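
To make the remapping concrete, here is a small sketch of the kind of ``map_location`` dictionary each process would build when rank 0 saved the checkpoint from ``cuda:0`` and every rank owns the GPU whose index equals its rank. The helper name ``make_map_location`` is hypothetical, and the one-GPU-per-rank layout is an assumption that matches the updated example in this tutorial.

```python
def make_map_location(src_rank, dst_rank):
    """Build a map_location dict that remaps tensors saved on one GPU to another."""
    return {f"cuda:{src_rank}": f"cuda:{dst_rank}"}


# Each rank redirects tensors saved on rank 0's device to its own device.
for rank in range(4):
    print(rank, make_map_location(0, rank))
```

The resulting dictionary is what gets passed as ``map_location`` to ``torch.load``, so that rank 3, for example, loads the checkpoint onto ``cuda:3`` instead of ``cuda:0``.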
166170
167171.. code:: python
168172
169173     def demo_checkpoint(rank, world_size):
174+         print(f"Running DDP checkpoint example on rank {rank}.")
170175         setup(rank, world_size)
171176
172-         # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
173-         # rank 2 uses GPUs [4, 5, 6, 7].
174-         n = torch.cuda.device_count() // world_size
175-         device_ids = list(range(rank * n, (rank + 1) * n))
176-
177-         model = ToyModel().to(device_ids[0])
178-         # output_device defaults to device_ids[0]
179-         ddp_model = DDP(model, device_ids=device_ids)
177+         model = ToyModel().to(rank)
178+         ddp_model = DDP(model, device_ids=[rank])
180179
181180         loss_fn = nn.MSELoss()
182181         optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
@@ -192,15 +191,13 @@ same machine using the same set of devices.
192191         # 0 saves it.
193192         dist.barrier()
194193         # configure map_location properly
195-         rank0_devices = [x - rank * len(device_ids) for x in device_ids]
196-         device_pairs = zip(rank0_devices, device_ids)
197-         map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
194+         map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
198195         ddp_model.load_state_dict(
199196             torch.load(CHECKPOINT_PATH, map_location=map_location))
200197
201198         optimizer.zero_grad()
202199         outputs = ddp_model(torch.randn(20, 10))
203-         labels = torch.randn(20, 5).to(device_ids[0])
200+         labels = torch.randn(20, 5).to(rank)
204201         loss_fn = nn.MSELoss()
205202         loss_fn(outputs, labels).backward()
206203         optimizer.step()
@@ -217,13 +214,8 @@ same machine using the same set of devices.
217214 Combine DDP with Model Parallelism
218215----------------------------------
219216
220- DDP also works with multi-GPU models, but replications within a process are not
221- supported. You need to create one process per module replica, which usually
222- leads to better performance compared to multiple replicas per process. DDP
223- wrapping multi-GPU models is especially helpful when training large models with
224- a huge amount of data. When using this feature, the multi-GPU model needs to be
225- carefully implemented to avoid hard-coded devices, because different model
226- replicas will be placed to different devices.
217+ DDP also works with multi-GPU models. DDP wrapping multi-GPU models is especially
218+ helpful when training large models with a huge amount of data.
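
Because GPUs cannot be shared across processes, each process wrapping a multi-GPU model needs its own disjoint set of devices. The sketch below shows one simple way to derive that set from the rank. It assumes every replica spans the same number of GPUs and that devices are assigned in contiguous blocks; ``devices_for_rank`` is a hypothetical helper, not part of PyTorch.

```python
def devices_for_rank(rank, gpus_per_replica, total_gpus):
    """Return the GPU indices owned by `rank` when each replica spans several GPUs."""
    start = rank * gpus_per_replica
    end = start + gpus_per_replica
    if end > total_gpus:
        raise ValueError(
            f"rank {rank} needs GPUs {start}..{end - 1}, but only {total_gpus} exist"
        )
    return list(range(start, end))


# Four processes, two GPUs per replica, on an 8-GPU server.
assignments = {rank: devices_for_rank(rank, 2, 8) for rank in range(4)}
print(assignments)  # {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7]}
```

Deriving devices from the rank, rather than hard-coding them in the model, is what lets the same model definition be placed on different devices in each process.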
227219
228220.. code:: python
229221
@@ -249,6 +241,7 @@ either the application or the model ``forward()`` method.
249241.. code:: python
250242
251243     def demo_model_parallel(rank, world_size):
244+         print(f"Running DDP with model parallel example on rank {rank}.")
252245         setup(rank, world_size)
253246
254247         # setup mp_model and devices for this process
@@ -271,8 +264,10 @@ either the application or the model ``forward()`` method.
271264
272265
273266     if __name__ == "__main__":
274-         run_demo(demo_basic, 2)
275-         run_demo(demo_checkpoint, 2)
276-
277-         if torch.cuda.device_count() >= 8:
278-             run_demo(demo_model_parallel, 4)
267+         n_gpus = torch.cuda.device_count()
268+         if n_gpus < 8:
269+             print(f"Requires at least 8 GPUs to run, but got {n_gpus}.")
270+         else:
271+             run_demo(demo_basic, 8)
272+             run_demo(demo_checkpoint, 8)
273+             run_demo(demo_model_parallel, 4)