@@ -260,6 +260,8 @@ def __init__(self):
260260 self .status = DistributedFusedAdam .GradientStatus .READY
261261 # Request object for asynchronous communication
262262 self .sync_request = None
263+ # Params that have generated grads
264+ self .grads_generated = set ()
263265
264266 def sync_wait (self ):
265267 """Wait for asynchronous communication to finish"""
@@ -420,7 +422,6 @@ def __init__(self,
420422
421423 # Objects for gradient synchronization
422424 self ._grads_buckets = collections .defaultdict (self .GradientBucket )
423- self ._grads_generated = set ()
424425 self ._pipeline_streams = [torch .cuda .Stream () for _ in range (self .pipeline_size )]
425426
426427 # Scale by factor before optimizer step. Used for grad
@@ -444,16 +445,39 @@ def __init__(self,
444445 # Attach hooks for gradient synchronization
445446 self ._register_post_backward_hooks ()
446447
448+ # Allocate contiguous gradient buffer if needed
449+ if self .contiguous_grad_buffer :
450+ self ._init_grad_buffer ()
451+
452+ def _make_post_backward_hook (self , param , param_group_id , param_id ):
453+ """Create callback function to call after param generates grad
454+
455+ The callback lazily initializes the param's optimizer state and tries launching grad sync.
456+
457+ """
458+ def hook (* unused ):
459+ with self ._lock :
460+ need_to_initialize = 'fragments' not in self .state [param ]
461+ if need_to_initialize :
462+ self ._init_param_state (param , param_group_id , param_id )
463+ if self .greedy_grad_copy :
464+ self ._grad_copy (param )
465+ if self .overlap_grad_sync :
466+ self ._try_start_bucket_grad_sync (
467+ params = [param ],
468+ ignore_last_bucket = need_to_initialize ,
469+ )
470+ return hook
471+
447472 def _register_post_backward_hooks (self ):
448473 """Attach hooks for gradient synchronization
449474
450- Optimizer state for parameters are initialized lazily as they
451- are encountered in the backward pass .
475+ Also synchronizes param values between processes and counts
476+ the number of parameters being optimized.
452477
453478 """
454479 self ._num_grads = 0
455- grad_buffer_size = 0
456- self ._lock = threading .Lock ()
480+ self ._lock = threading .Lock () # TODO: confirm whether this lock around the post-backward hook body is actually required
457481 self ._grad_accs = []
458482 for param_group_id , group in enumerate (self .param_groups ):
459483 for param_id , param in enumerate (group ['params' ]):
@@ -465,40 +489,34 @@ def _register_post_backward_hooks(self):
465489 if param .requires_grad :
466490 self ._num_grads += 1
467491
468- # Callback after gradient is generated
469- def wrapper (p , p_group_id , p_id ):
470- p_tmp = p .expand_as (p )
471- grad_acc = p_tmp .grad_fn .next_functions [0 ][0 ]
472- def reduction_hook (* unused ):
473- with self ._lock :
474- if 'fragments' not in self .state [p ]:
475- self ._init_param_state (p , p_group_id , p_id )
476- if self .greedy_grad_copy :
477- self ._grad_copy (p )
478- if self .overlap_grad_sync :
479- self ._try_start_bucket_grad_sync (
480- params = [p ],
481- ignore_last_bucket = True ,
482- )
483- grad_acc .register_hook (reduction_hook )
484- self ._grad_accs .append (grad_acc )
485- wrapper (param , param_group_id , param_id )
486-
487- # Gradient size, with padding for alignment
492+ # Register callback for after grad is generated
493+ param_tmp = param .expand_as (param )
494+ grad_acc = param_tmp .grad_fn .next_functions [0 ][0 ]
495+ hook = self ._make_post_backward_hook (
496+ param ,
497+ param_group_id ,
498+ param_id ,
499+ )
500+ grad_acc .register_hook (hook )
501+ self ._grad_accs .append (grad_acc )
502+
503+ def _init_grad_buffer (self ):
504+ """Allocate contiguous buffer for grad buckets"""
505+ grad_buffer_size = 0
506+ for group in self .param_groups :
507+ for param in group ['params' ]:
508+ if param .requires_grad :
488509 grad_size = _round_to_multiple (param .numel (), self .alignment )
489510 grad_buffer_size += grad_size
490-
491- # Allocate contiguous gradient buffer if needed
492- if self .contiguous_grad_buffer :
493- grad_buffer_size = _round_to_multiple (
494- grad_buffer_size ,
495- self .bucket_size ,
496- )
497- self ._grad_buffer = torch .zeros (
498- [grad_buffer_size ],
499- dtype = self .dtype ,
500- device = self .device ,
501- )
511+ grad_buffer_size = _round_to_multiple (
512+ grad_buffer_size ,
513+ self .bucket_size ,
514+ )
515+ self ._grad_buffer = torch .zeros (
516+ [grad_buffer_size ],
517+ dtype = self .dtype ,
518+ device = self .device ,
519+ )
502520
503521 def parameters (self ):
504522 """Returns an iterator over optimizer parameters"""
@@ -654,7 +672,6 @@ def zero_grad(self, set_to_none=True):
654672 param .grad .zero_ ()
655673
656674 # Reset other state
657- self ._grads_generated = set ()
658675 self ._grad_scale = torch .full ([], 1.0 , dtype = self .dtype , device = self .device )
659676 self ._grad_norm = None
660677
@@ -744,13 +761,10 @@ def _force_bucket_grad_sync(self):
744761 device = self .device ,
745762 )
746763
747- # Reset set of generated gradients
748- self ._grads_generated = set ()
749-
750764 def _try_start_bucket_grad_sync (
751765 self ,
752766 params = [],
753- ignore_last_bucket = True ,
767+ ignore_last_bucket = False ,
754768 ):
755769 """Launches gradient synchronization if enough buckets are ready
756770
@@ -770,38 +784,28 @@ def _try_start_bucket_grad_sync(
770784
771785 # Register params that have generated grads
772786 for param in params :
773- self ._grads_generated .add (param )
774787 for fragment in self .state [param ]['fragments' ]:
775788 bucket_id = fragment .bucket_id
789+ bucket = self ._grads_buckets [bucket_id ]
776790 bucket_fragments = self .state ['buckets' ][bucket_id ].fragments
777- is_filled = True
778- for other_fragment in reversed (bucket_fragments ):
779- param_group_id = other_fragment .param_group_id
780- param_id = other_fragment .param_id
781- other_param = self .param_groups [param_group_id ]['params' ][param_id ]
782- if other_param not in self ._grads_generated :
783- is_filled = False
784- break
785- if is_filled :
786- bucket = self ._grads_buckets [bucket_id ]
791+ bucket .grads_generated .add (param )
792+ if len (bucket .grads_generated ) == len (bucket_fragments ):
787793 bucket .status = self .GradientStatus .FULLY_FILLED
788794
789795 # Launch reductions if enough buckets are ready
790- if len (self ._grads_generated ) == self ._num_grads :
791- self ._force_bucket_grad_sync ()
792- else :
793- filled_buckets = []
794- for bucket_id , bucket in sorted (self ._grads_buckets .items ()):
795- if ignore_last_bucket and bucket_id == len (self .state ['buckets' ])- 1 :
796- continue
797- if bucket .status == self .GradientStatus .FULLY_FILLED :
798- filled_buckets .append (bucket )
799- pipeline_size = _round_to_multiple (
800- len (filled_buckets ),
801- self .pipeline_size ,
802- )
803- if pipeline_size > 0 :
804- self ._start_bucket_grad_sync (filled_buckets [:pipeline_size ])
796+ filled_buckets = []
797+ for bucket_id , bucket in sorted (self ._grads_buckets .items ()):
798+ if ignore_last_bucket and bucket_id == len (self .state ['buckets' ])- 1 :
799+ continue
800+ if bucket .status == self .GradientStatus .FULLY_FILLED :
801+ filled_buckets .append (bucket )
802+ pipeline_size = _round_to_multiple (
803+ len (filled_buckets ),
804+ self .pipeline_size ,
805+ round_up = False ,
806+ )
807+ if pipeline_size > 0 :
808+ self ._start_bucket_grad_sync (filled_buckets [:pipeline_size ])
805809
806810 def _start_bucket_grad_sync (self , buckets ):
807811 """Synchronize gradient buckets
@@ -827,6 +831,7 @@ def _start_bucket_grad_sync(self, buckets):
827831 # Reduce-scatter over distributed process group
828832 for i , bucket in enumerate (buckets ):
829833 bucket .status = self .GradientStatus .SYNCING
834+ bucket .grads_generated .clear ()
830835 bucket .sync_wait ()
831836 if self .distributed_size == 1 :
832837 bucket .sync_grads_shard = bucket .grads_bucket
@@ -932,7 +937,7 @@ def grad_sync(self):
932937 )
933938 self ._force_bucket_grad_sync ()
934939
935- def _local_grad_norm (self , parameters = [] , norm_type = 2.0 ):
940+ def _local_grad_norm (self , parameters = None , norm_type = 2.0 ):
936941 """Local contribution to parameter gradient norm
937942
938943 Returns square of 2-norm. Other norms are not yet supported.
@@ -948,7 +953,7 @@ def _local_grad_norm(self, parameters=[], norm_type=2.0):
948953 # Make sure that gradients have been reduced
949954 self .grad_sync ()
950955
951- if not parameters or len (parameters ) == self ._num_grads :
956+ if parameters is None or len (parameters ) == self ._num_grads :
952957 # Compute norm of all local gradients
953958 dummy_overflow_buf = torch .zeros ([1 ], dtype = torch .int32 , device = self .device )
954959 grad_norm_sq = multi_tensor_applier (
@@ -982,7 +987,7 @@ def _local_grad_norm(self, parameters=[], norm_type=2.0):
982987 grad_norm_sq = grad_norm_sq .view ([])
983988 return grad_norm_sq
984989
985- def grad_norm (self , parameters = [] , norm_type = 2.0 , force = False ):
990+ def grad_norm (self , parameters = None , norm_type = 2.0 , force = False ):
986991 """Gradient norm of parameters in optimizer
987992
988993 The norm is computed over all gradients together, as if they
@@ -1016,7 +1021,7 @@ def grad_norm(self, parameters=[], norm_type=2.0, force=False):
10161021 self ._grad_norm = grad_norm_sq .sqrt ()
10171022 return self ._grad_norm .detach ()
10181023
1019- def clip_grad_norm (self , max_norm , parameters = [] , norm_type = 2.0 ):
1024+ def clip_grad_norm (self , max_norm , parameters = None , norm_type = 2.0 ):
10201025 """Clips gradient norm of parameters in optimizer
10211026
10221027 The norm is computed over all gradients together, as if they
@@ -1330,28 +1335,25 @@ def state_dict(self, gather_on_root=True):
13301335 # Split data into chunks and gather on root rank
13311336 # Note: Assuming we are using the NCCL backend, communication
13321337 # must happen on the GPU. We split the data into fixed-size
1333- # chunks so that the GPU memory usage is limited to
1334- # (chunk_size * distributed_size) bytes.
1338+ # chunks to limit GPU memory usage.
13351339 # TODO: Avoid chunking with direct communication between CPUs
13361340 main_stream = torch .cuda .current_stream ()
13371341 for stream in self ._pipeline_streams :
13381342 stream .wait_stream (main_stream )
13391343 for stream_id , offset in enumerate (range (0 , max_state_size , chunk_size )):
13401344 stream_id %= self .pipeline_size
1341-
1342- # Buffers for chunk
1343- if self .distributed_rank == 0 :
1344- gathered_chunks = [
1345- gathered_chunks_buffers [stream_id ][i * chunk_size :(i + 1 )* chunk_size ]
1346- for i in range (self .distributed_size )
1347- ]
1348- else :
1349- chunk = chunk_buffers [stream_id ]
1350-
1351- # Perform communication on parallel stream
13521345 stream = self ._pipeline_streams [stream_id ]
13531346 with torch .cuda .stream (stream ):
13541347
1348+ # Buffers for chunk
1349+ if self .distributed_rank == 0 :
1350+ gathered_chunks = [
1351+ gathered_chunks_buffers [stream_id ][i * chunk_size :(i + 1 )* chunk_size ]
1352+ for i in range (self .distributed_size )
1353+ ]
1354+ else :
1355+ chunk = chunk_buffers [stream_id ]
1356+
13551357 # Copy to GPU
13561358 if self .distributed_rank != 0 and offset < local_state_size :
13571359 local_chunk_size = min (chunk_size , local_state_size - offset )
@@ -1366,24 +1368,30 @@ def state_dict(self, gather_on_root=True):
13661368 )
13671369
13681370 # Gather on root
1369- if self .distributed_rank == 0 :
1370- if self ._gather_no_copy :
1371- no_copy_kwarg = { 'no_copy' : True }
1371+ # Note: Call in main stream to avoid memory pool
1372+ # overheads from internal memory allocations in
1373+ # gather.
1374+ main_stream .wait_stream (stream )
1375+ with torch .cuda .stream (main_stream ):
1376+ if self .distributed_rank == 0 :
1377+ if self ._gather_no_copy :
1378+ no_copy_kwarg = { 'no_copy' : True }
1379+ else :
1380+ no_copy_kwarg = {}
1381+ torch .distributed .gather (
1382+ gathered_chunks [0 ],
1383+ gathered_chunks ,
1384+ dst = self ._process_group_ranks [0 ],
1385+ group = self .process_group ,
1386+ ** no_copy_kwarg ,
1387+ )
13721388 else :
1373- no_copy_kwarg = {}
1374- torch .distributed .gather (
1375- gathered_chunks [0 ],
1376- gathered_chunks ,
1377- dst = self ._process_group_ranks [0 ],
1378- group = self .process_group ,
1379- ** no_copy_kwarg ,
1380- )
1381- else :
1382- torch .distributed .gather (
1383- chunk ,
1384- dst = self ._process_group_ranks [0 ],
1385- group = self .process_group ,
1386- )
1389+ torch .distributed .gather (
1390+ chunk ,
1391+ dst = self ._process_group_ranks [0 ],
1392+ group = self .process_group ,
1393+ )
1394+ stream .wait_stream (main_stream )
13871395
13881396 # Copy back to CPU
13891397 if self .distributed_rank == 0 :
0 commit comments