@@ -35,13 +35,13 @@ At a high level FSDP works as follows:
3535
3636*In forward path*
3737
38- * Run allgather to collect all shards from all ranks to recover the full parameter in this FSDP unit
38+ * Run all_gather to collect all shards from all ranks to recover the full parameter in this FSDP unit
3939* Run forward computation
4040* Discard parameter shards it has just collected
4141
4242*In backward path*
4343
44- * Run allgather to collect all shards from all ranks to recover the full parameter in this FSDP unit
44+ * Run all_gather to collect all shards from all ranks to recover the full parameter in this FSDP unit
4545* Run backward computation
4646* Run reduce_scatter to sync gradients
4747* Discard parameters.
@@ -155,7 +155,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
155155 ddp_loss[0 ] += loss.item()
156156 ddp_loss[1 ] += len (data)
157157
158- dist.reduce (ddp_loss, 0 , op = dist.ReduceOp.SUM )
158+ dist.all_reduce (ddp_loss, op = dist.ReduceOp.SUM )
159159 if rank == 0 :
160160 print (' Train Epoch: {} \t Loss: {:.6f } ' .format(epoch, ddp_loss[0 ] / ddp_loss[1 ]))
161161
@@ -176,7 +176,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
176176 ddp_loss[1 ] += pred.eq(target.view_as(pred)).sum().item()
177177 ddp_loss[2 ] += len (data)
178178
179- dist.reduce (ddp_loss, 0 , op = dist.ReduceOp.SUM )
179+ dist.all_reduce (ddp_loss, op = dist.ReduceOp.SUM )
180180
181181 if rank == 0 :
182182 test_loss = ddp_loss[0 ] / ddp_loss[2 ]
0 commit comments