@@ -774,10 +774,15 @@ def __init__(
774774 Tuple [torch .dtype , torch .dtype , torch .dtype ], torch .Tensor
775775 ] = {}
776776
777- # Side streams for optimizer step and communication
777+ # Side streams for state dict communication
778778 self ._pipeline_streams : List [torch .cuda .Stream ] = [
779- torch .cuda .Stream () for _ in range (self .pipeline_size + 1 )
779+ torch .cuda .Stream () for _ in range (self .pipeline_size )
780780 ]
781+ # Side streams for gradients and parameters communication
782+ self ._comm_streams : List [torch .cuda .Stream ] = [
783+ torch .cuda .Stream () for _ in range (self .pipeline_size )
784+ ]
785+ self ._last_comm_stream_id : int = - 1
781786
782787 # Scale by factor before optimizer step. Used for grad
783788 # clipping and gradient scaler.
@@ -1951,8 +1956,11 @@ def _start_bucket_grad_sync(self, buckets: List[GradientBucket]) -> None:
19511956 bucket .grads_shard = bucket .grads_shard .clone ()
19521957
19531958 # Side stream for communication
1959+ # If new bucket is ready before last bucket communication finishes, use multiple
1960+ # communication streams could help pipeline reduce-scatter and all-reduce.
19541961 main_stream = torch .cuda .current_stream ()
1955- comm_stream = self ._pipeline_streams [- 1 ]
1962+ self ._last_comm_stream_id = (self ._last_comm_stream_id + 1 ) % len (self ._comm_streams )
1963+ comm_stream = self ._comm_streams [self ._last_comm_stream_id ]
19561964 comm_stream .wait_stream (main_stream )
19571965
19581966 # Reduce-scatter over distributed process group
@@ -1995,8 +2003,8 @@ def _start_bucket_grad_sync(self, buckets: List[GradientBucket]) -> None:
19952003 def _finish_bucket_grad_sync (self ) -> None :
19962004 """Wait for any gradient synchronizations that are in progress"""
19972005 main_stream = torch .cuda .current_stream ()
1998- comm_stream = self ._pipeline_streams [ - 1 ]
1999- main_stream .wait_stream (comm_stream )
2006+ for comm_stream in self ._comm_streams :
2007+ main_stream .wait_stream (comm_stream )
20002008 for bucket_id , bucket in sorted (self ._grads_buckets .items ()):
20012009 if bucket .status == self .GradientStatus .SYNCING :
20022010 # Accumulate gradient in local shard
@@ -2103,7 +2111,8 @@ def _start_bucket_param_sync(self, buckets: List[ParameterBucket]) -> None:
21032111
21042112 # Side stream for communication
21052113 main_stream = torch .cuda .current_stream ()
2106- comm_stream = self ._pipeline_streams [- 1 ]
2114+ self ._last_comm_stream_id = (self ._last_comm_stream_id + 1 ) % len (self ._comm_streams )
2115+ comm_stream = self ._comm_streams [self ._last_comm_stream_id ]
21072116 comm_stream .wait_stream (main_stream )
21082117
21092118 # All-gather over distributed process group
@@ -2126,8 +2135,8 @@ def _start_bucket_param_sync(self, buckets: List[ParameterBucket]) -> None:
21262135 def _finish_bucket_param_sync (self ) -> None :
21272136 """Wait for any param synchronizations that are in progress"""
21282137 main_stream = torch .cuda .current_stream ()
2129- comm_stream = self ._pipeline_streams [ - 1 ]
2130- main_stream .wait_stream (comm_stream )
2138+ for comm_stream in self ._comm_streams :
2139+ main_stream .wait_stream (comm_stream )
21312140 for bucket_id , bucket in self ._params_buckets .items ():
21322141 if bucket .status == self .ParameterStatus .SYNCING :
21332142 bucket .params_shard = None
0 commit comments