Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b9d758c

Browse files
authored
Pipeline reduce-scatter and all-reduce. (#1895)
1 parent badf311 commit b9d758c

1 file changed

Lines changed: 17 additions & 8 deletions

File tree

apex/contrib/optimizers/distributed_fused_adam.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -774,10 +774,15 @@ def __init__(
774774
Tuple[torch.dtype, torch.dtype, torch.dtype], torch.Tensor
775775
] = {}
776776

777-
# Side streams for optimizer step and communication
777+
# Side streams for state dict communication
778778
self._pipeline_streams: List[torch.cuda.Stream] = [
779-
torch.cuda.Stream() for _ in range(self.pipeline_size + 1)
779+
torch.cuda.Stream() for _ in range(self.pipeline_size)
780780
]
781+
# Side streams for gradients and parameters communication
782+
self._comm_streams: List[torch.cuda.Stream] = [
783+
torch.cuda.Stream() for _ in range(self.pipeline_size)
784+
]
785+
self._last_comm_stream_id: int = -1
781786

782787
# Scale by factor before optimizer step. Used for grad
783788
# clipping and gradient scaler.
@@ -1951,8 +1956,11 @@ def _start_bucket_grad_sync(self, buckets: List[GradientBucket]) -> None:
19511956
bucket.grads_shard = bucket.grads_shard.clone()
19521957

19531958
# Side stream for communication
1959+
# If new bucket is ready before last bucket communication finishes, use multiple
1960+
# communication streams could help pipeline reduce-scatter and all-reduce.
19541961
main_stream = torch.cuda.current_stream()
1955-
comm_stream = self._pipeline_streams[-1]
1962+
self._last_comm_stream_id = (self._last_comm_stream_id + 1) % len(self._comm_streams)
1963+
comm_stream = self._comm_streams[self._last_comm_stream_id]
19561964
comm_stream.wait_stream(main_stream)
19571965

19581966
# Reduce-scatter over distributed process group
@@ -1995,8 +2003,8 @@ def _start_bucket_grad_sync(self, buckets: List[GradientBucket]) -> None:
19952003
def _finish_bucket_grad_sync(self) -> None:
19962004
"""Wait for any gradient synchronizations that are in progress"""
19972005
main_stream = torch.cuda.current_stream()
1998-
comm_stream = self._pipeline_streams[-1]
1999-
main_stream.wait_stream(comm_stream)
2006+
for comm_stream in self._comm_streams:
2007+
main_stream.wait_stream(comm_stream)
20002008
for bucket_id, bucket in sorted(self._grads_buckets.items()):
20012009
if bucket.status == self.GradientStatus.SYNCING:
20022010
# Accumulate gradient in local shard
@@ -2103,7 +2111,8 @@ def _start_bucket_param_sync(self, buckets: List[ParameterBucket]) -> None:
21032111

21042112
# Side stream for communication
21052113
main_stream = torch.cuda.current_stream()
2106-
comm_stream = self._pipeline_streams[-1]
2114+
self._last_comm_stream_id = (self._last_comm_stream_id + 1) % len(self._comm_streams)
2115+
comm_stream = self._comm_streams[self._last_comm_stream_id]
21072116
comm_stream.wait_stream(main_stream)
21082117

21092118
# All-gather over distributed process group
@@ -2126,8 +2135,8 @@ def _start_bucket_param_sync(self, buckets: List[ParameterBucket]) -> None:
21262135
def _finish_bucket_param_sync(self) -> None:
21272136
"""Wait for any param synchronizations that are in progress"""
21282137
main_stream = torch.cuda.current_stream()
2129-
comm_stream = self._pipeline_streams[-1]
2130-
main_stream.wait_stream(comm_stream)
2138+
for comm_stream in self._comm_streams:
2139+
main_stream.wait_stream(comm_stream)
21312140
for bucket_id, bucket in self._params_buckets.items():
21322141
if bucket.status == self.ParameterStatus.SYNCING:
21332142
bucket.params_shard = None

0 commit comments

Comments
 (0)