
Commit e086db6

update dist.reduce to proper dist.all_reduce (#1926)
1 parent d909b5d commit e086db6

1 file changed: intermediate_source/FSDP_tutorial.rst (4 additions & 4 deletions)
@@ -35,13 +35,13 @@ At high level FSDP works as follow:
 
 *In forward path*
 
-* Run allgather to collect all shards from all ranks to recover the full parameter in this FSDP unit
+* Run all_gather to collect all shards from all ranks to recover the full parameter in this FSDP unit
 * Run forward computation
 * Discard parameter shards it has just collected
 
 *In backward path*
 
-* Run allgather to collect all shards from all ranks to recover the full parameter in this FSDP unit
+* Run all_gather to collect all shards from all ranks to recover the full parameter in this FSDP unit
 * Run backward computation
 * Run reduce_scatter to sync gradients
 * Discard parameters.
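The FSDP steps in the hunk above turn on two collectives: all_gather (each rank contributes its shard and every rank receives the full parameter) and reduce_scatter (gradients are summed across ranks and each rank keeps only its own shard). A toy, torch-free sketch of those semantics, with hypothetical helper names and plain Python lists standing in for per-rank tensors:

```python
def all_gather(shards):
    """Each rank holds one shard; after all_gather every rank has the full list."""
    full = [x for shard in shards for x in shard]
    return [full for _ in shards]  # every rank now sees the full parameter

def reduce_scatter(grads_per_rank):
    """Sum full-size gradients element-wise across ranks, then give each rank one shard."""
    n_ranks = len(grads_per_rank)
    summed = [sum(vals) for vals in zip(*grads_per_rank)]
    shard_len = len(summed) // n_ranks
    return [summed[r * shard_len:(r + 1) * shard_len] for r in range(n_ranks)]

# Two ranks, each owning half of a 4-element flattened parameter
shards = [[1.0, 2.0], [3.0, 4.0]]
print(all_gather(shards)[0])   # full parameter recovered on rank 0: [1.0, 2.0, 3.0, 4.0]

# After backward, each rank holds a full-size local gradient
grads = [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]]
print(reduce_scatter(grads))   # one synced shard per rank: [[3.0, 3.0], [3.0, 3.0]]
```

The real `torch.distributed` calls operate in place on tensors inside an initialized process group; this sketch only mirrors the data movement the bullet list describes.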
@@ -155,7 +155,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
         ddp_loss[0] += loss.item()
         ddp_loss[1] += len(data)
-    dist.reduce(ddp_loss, 0, op=dist.ReduceOp.SUM)
+    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
     if rank == 0:
         print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))
161161
@@ -176,7 +176,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
             ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
             ddp_loss[2] += len(data)
-    dist.reduce(ddp_loss, 0, op=dist.ReduceOp.SUM)
+    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
 
     if rank == 0:
         test_loss = ddp_loss[0] / ddp_loss[2]
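The substance of both hunks is the same fix: `dist.reduce(tensor, dst, ...)` leaves the reduced result only on the destination rank (the tensor on every other rank is not guaranteed to hold the sum), while `dist.all_reduce` leaves the summed result on every rank, so `ddp_loss` is consistent everywhere. A toy model of that difference, with hypothetical function names and plain numbers standing in for the per-rank tensors:

```python
def reduce_(values, dst=0):
    """Mimic dist.reduce: only the destination rank ends up with the sum."""
    total = sum(values)
    return [total if rank == dst else v for rank, v in enumerate(values)]

def all_reduce(values):
    """Mimic dist.all_reduce: every rank ends up with the sum."""
    total = sum(values)
    return [total for _ in values]

losses = [1.0, 2.0, 4.0]   # per-rank ddp_loss values before the collective
print(reduce_(losses))     # [7.0, 2.0, 4.0] -- ranks 1 and 2 keep their local values
print(all_reduce(losses))  # [7.0, 7.0, 7.0] -- consistent on all ranks
```

Since the tutorial only prints from rank 0, either collective would feed that print; all_reduce additionally guarantees that every rank's `ddp_loss` holds the global sum rather than a stale local value.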
