@@ -66,7 +66,7 @@ consider using ``DistributedDataParallel`` over ``DataParallel``:
 Basic Use Case
 --------------

-To create DDP modules, first set up process groups properly. More details can
+To create a DDP module, you must first set up process groups properly. More details can
 be found in
 `Writing Distributed Applications with PyTorch <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__.

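The setup this paragraph refers to is, at minimum, a call to ``init_process_group`` on every rank. A minimal sketch, assuming the gloo backend and an environment-variable rendezvous on localhost (the function name and port below are illustrative):

.. code:: python

    import os

    import torch.distributed as dist


    def setup(rank, world_size):
        # Rendezvous information shared by all ranks in the group.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"

        # Every rank must initialize the default process group before
        # constructing a DDP module.
        dist.init_process_group("gloo", rank=rank, world_size=world_size)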
@@ -105,10 +105,10 @@ be found in
     def cleanup():
         dist.destroy_process_group()

-Now, let's create a toy module, wrap it with DDP, and feed it with some dummy
+Now, let's create a toy module, wrap it with DDP, and feed it some dummy
 input data. Please note, as DDP broadcasts model states from rank 0 process to
-all other processes in the DDP constructor, you don't need to worry about
-different DDP processes start from different model parameter initial values.
+all other processes in the DDP constructor, you do not need to worry about
+different DDP processes starting from different initial model parameter values.

 .. code:: python

@@ -150,7 +150,7 @@ different DDP processes start from different model parameter initial values.
                  join=True)

 As you can see, DDP wraps lower-level distributed communication details and
-provides a clean API as if it is a local model. Gradient synchronization
+provides a clean API as if it were a local model. Gradient synchronization
 communications take place during the backward pass and overlap with the
 backward computation. When the ``backward()`` returns, ``param.grad`` already
 contains the synchronized gradient tensor. For basic use cases, DDP only
@@ -164,10 +164,10 @@ In DDP, the constructor, the forward pass, and the backward pass are
 distributed synchronization points. Different processes are expected to launch
 the same number of synchronizations and reach these synchronization points in
 the same order and enter each synchronization point at roughly the same time.
-Otherwise, fast processes might arrive early and timeout on waiting for
-stragglers. Hence, users are responsible for balancing workloads distributions
+Otherwise, fast processes might arrive early and timeout while waiting for
+stragglers. Hence, users are responsible for balancing workload distributions
 across processes. Sometimes, skewed processing speeds are inevitable due to,
-e.g., network delays, resource contentions, unpredictable workload spikes. To
+e.g., network delays, resource contentions, or unpredictable workload spikes. To
 avoid timeouts in these situations, make sure that you pass a sufficiently
 large ``timeout`` value when calling
 `init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__.
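To illustrate the last point: the ``timeout`` argument is a ``datetime.timedelta`` passed to the same initialization call sketched earlier. A minimal example, with an arbitrary one-hour value:

.. code:: python

    import datetime

    # Inside setup(): give straggler ranks more headroom than the
    # 30-minute default before collective calls time out.
    dist.init_process_group(
        "gloo",
        rank=rank,
        world_size=world_size,
        timeout=datetime.timedelta(minutes=60),
    )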
@@ -182,10 +182,10 @@ for more details. When using DDP, one optimization is to save the model in
 only one process and then load it to all processes, reducing write overhead.
 This is correct because all processes start from the same parameters and
 gradients are synchronized in backward passes, and hence optimizers should keep
-setting parameters to the same values. If you use this optimization, make sure all
-processes do not start loading before the saving is finished. Besides, when
+setting parameters to the same values. If you use this optimization, make sure no process starts
+loading before the saving is finished. Additionally, when
 loading the module, you need to provide an appropriate ``map_location``
-argument to prevent a process to step into others' devices. If ``map_location``
+argument to prevent a process from stepping into others' devices. If ``map_location``
 is missing, ``torch.load`` will first load the module to CPU and then copy each
 parameter to where it was saved, which would result in all processes on the
 same machine using the same set of devices. For more advanced failure recovery
@@ -200,8 +200,6 @@ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elast
         model = ToyModel().to(rank)
         ddp_model = DDP(model, device_ids=[rank])

-        loss_fn = nn.MSELoss()
-        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

         CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
         if rank == 0:
@@ -218,10 +216,13 @@ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elast
         ddp_model.load_state_dict(
             torch.load(CHECKPOINT_PATH, map_location=map_location))

+        loss_fn = nn.MSELoss()
+        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
+
         optimizer.zero_grad()
         outputs = ddp_model(torch.randn(20, 10))
         labels = torch.randn(20, 5).to(rank)
-        loss_fn = nn.MSELoss()
+
         loss_fn(outputs, labels).backward()
         optimizer.step()

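Stepping back to the save-and-load discussion above, a minimal sketch of one way to enforce the ordering "no rank starts loading before rank 0 has finished saving", using ``dist.barrier()``; the helper name here is hypothetical:

.. code:: python

    import tempfile

    import torch
    import torch.distributed as dist


    def save_then_reload(ddp_model, rank):
        # Hypothetical helper: only rank 0 writes the checkpoint.
        checkpoint_path = tempfile.gettempdir() + "/model.checkpoint"
        if rank == 0:
            torch.save(ddp_model.state_dict(), checkpoint_path)

        # Block every rank here until rank 0 has finished writing.
        dist.barrier()

        # Remap tensors saved from rank 0's device onto this rank's device.
        map_location = {"cuda:0": f"cuda:{rank}"}
        ddp_model.load_state_dict(
            torch.load(checkpoint_path, map_location=map_location))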
@@ -234,7 +235,7 @@ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elast

         cleanup()

-Combine DDP with Model Parallelism
+Combining DDP with Model Parallelism
 ----------------------------------

 DDP also works with multi-GPU models. DDP wrapping multi-GPU models is especially