Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 2b0e837

Browse files
authored
[contrib][DistributedFusedAdam] Support overlapped grad sync with Megatron pipeline parallelism (#1475)
* Refactor how dist Adam handles overlapped grad sync Each grad bucket independently keeps track of grads that have been generated. Add helper function to create callback functions. Change default param arg in grad norm functions to None. Perform communication for checkpointing in main stream to avoid memory pool overheads. * Support Megatron pipeline parallelism with async grad reduction Enables async grad reduction in first pipeline stage during last backward pass, and disables async grad reduction in all other pipeline stages. * Review suggestions from crcrpar Add unit test for pipeline parallelism with custom sync context. Style tweaks. * Use unittest assert functions in pipeline parallelism test Review suggestion from crcrpar
1 parent 6b5405e commit 2b0e837

4 files changed

Lines changed: 248 additions & 113 deletions

File tree

apex/contrib/optimizers/distributed_fused_adam.py

Lines changed: 110 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,8 @@ def __init__(self):
260260
self.status = DistributedFusedAdam.GradientStatus.READY
261261
# Request object for asynchronous communication
262262
self.sync_request = None
263+
# Params that have generated grads
264+
self.grads_generated = set()
263265

264266
def sync_wait(self):
265267
"""Wait for asynchronous communication to finish"""
@@ -420,7 +422,6 @@ def __init__(self,
420422

421423
# Objects for gradient synchronization
422424
self._grads_buckets = collections.defaultdict(self.GradientBucket)
423-
self._grads_generated = set()
424425
self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size)]
425426

426427
# Scale by factor before optimizer step. Used for grad
@@ -444,16 +445,39 @@ def __init__(self,
444445
# Attach hooks for gradient synchronization
445446
self._register_post_backward_hooks()
446447

448+
# Allocate contiguous gradient buffer if needed
449+
if self.contiguous_grad_buffer:
450+
self._init_grad_buffer()
451+
452+
def _make_post_backward_hook(self, param, param_group_id, param_id):
453+
"""Create callback function to call after param generates grad
454+
455+
Lazily initialize parameter and try launching grad sync.
456+
457+
"""
458+
def hook(*unused):
459+
with self._lock:
460+
need_to_initialize = 'fragments' not in self.state[param]
461+
if need_to_initialize:
462+
self._init_param_state(param, param_group_id, param_id)
463+
if self.greedy_grad_copy:
464+
self._grad_copy(param)
465+
if self.overlap_grad_sync:
466+
self._try_start_bucket_grad_sync(
467+
params=[param],
468+
ignore_last_bucket=need_to_initialize,
469+
)
470+
return hook
471+
447472
def _register_post_backward_hooks(self):
448473
"""Attach hooks for gradient synchronization
449474
450-
Optimizer state for parameters are initialized lazily as they
451-
are encountered in the backward pass.
475+
Also synchronizes param values between processes and counts
476+
number of parameters being optimized.
452477
453478
"""
454479
self._num_grads = 0
455-
grad_buffer_size = 0
456-
self._lock = threading.Lock()
480+
self._lock = threading.Lock() # Not sure if needed
457481
self._grad_accs = []
458482
for param_group_id, group in enumerate(self.param_groups):
459483
for param_id, param in enumerate(group['params']):
@@ -465,40 +489,34 @@ def _register_post_backward_hooks(self):
465489
if param.requires_grad:
466490
self._num_grads += 1
467491

468-
# Callback after gradient is generated
469-
def wrapper(p, p_group_id, p_id):
470-
p_tmp = p.expand_as(p)
471-
grad_acc = p_tmp.grad_fn.next_functions[0][0]
472-
def reduction_hook(*unused):
473-
with self._lock:
474-
if 'fragments' not in self.state[p]:
475-
self._init_param_state(p, p_group_id, p_id)
476-
if self.greedy_grad_copy:
477-
self._grad_copy(p)
478-
if self.overlap_grad_sync:
479-
self._try_start_bucket_grad_sync(
480-
params=[p],
481-
ignore_last_bucket=True,
482-
)
483-
grad_acc.register_hook(reduction_hook)
484-
self._grad_accs.append(grad_acc)
485-
wrapper(param, param_group_id, param_id)
486-
487-
# Gradient size, with padding for alignment
492+
# Register callback for after grad is generated
493+
param_tmp = param.expand_as(param)
494+
grad_acc = param_tmp.grad_fn.next_functions[0][0]
495+
hook = self._make_post_backward_hook(
496+
param,
497+
param_group_id,
498+
param_id,
499+
)
500+
grad_acc.register_hook(hook)
501+
self._grad_accs.append(grad_acc)
502+
503+
def _init_grad_buffer(self):
504+
"""Allocate contiguous buffer for grad buckets"""
505+
grad_buffer_size = 0
506+
for group in self.param_groups:
507+
for param in group['params']:
508+
if param.requires_grad:
488509
grad_size = _round_to_multiple(param.numel(), self.alignment)
489510
grad_buffer_size += grad_size
490-
491-
# Allocate contiguous gradient buffer if needed
492-
if self.contiguous_grad_buffer:
493-
grad_buffer_size = _round_to_multiple(
494-
grad_buffer_size,
495-
self.bucket_size,
496-
)
497-
self._grad_buffer = torch.zeros(
498-
[grad_buffer_size],
499-
dtype=self.dtype,
500-
device=self.device,
501-
)
511+
grad_buffer_size = _round_to_multiple(
512+
grad_buffer_size,
513+
self.bucket_size,
514+
)
515+
self._grad_buffer = torch.zeros(
516+
[grad_buffer_size],
517+
dtype=self.dtype,
518+
device=self.device,
519+
)
502520

503521
def parameters(self):
504522
"""Returns an iterator over optimizer parameters"""
@@ -654,7 +672,6 @@ def zero_grad(self, set_to_none=True):
654672
param.grad.zero_()
655673

656674
# Reset other state
657-
self._grads_generated = set()
658675
self._grad_scale = torch.full([], 1.0, dtype=self.dtype, device=self.device)
659676
self._grad_norm = None
660677

@@ -744,13 +761,10 @@ def _force_bucket_grad_sync(self):
744761
device=self.device,
745762
)
746763

747-
# Reset set of generated gradients
748-
self._grads_generated = set()
749-
750764
def _try_start_bucket_grad_sync(
751765
self,
752766
params=[],
753-
ignore_last_bucket=True,
767+
ignore_last_bucket=False,
754768
):
755769
"""Launches gradient synchronization if enough buckets are ready
756770
@@ -770,38 +784,28 @@ def _try_start_bucket_grad_sync(
770784

771785
# Register params that have generated grads
772786
for param in params:
773-
self._grads_generated.add(param)
774787
for fragment in self.state[param]['fragments']:
775788
bucket_id = fragment.bucket_id
789+
bucket = self._grads_buckets[bucket_id]
776790
bucket_fragments = self.state['buckets'][bucket_id].fragments
777-
is_filled = True
778-
for other_fragment in reversed(bucket_fragments):
779-
param_group_id = other_fragment.param_group_id
780-
param_id = other_fragment.param_id
781-
other_param = self.param_groups[param_group_id]['params'][param_id]
782-
if other_param not in self._grads_generated:
783-
is_filled = False
784-
break
785-
if is_filled:
786-
bucket = self._grads_buckets[bucket_id]
791+
bucket.grads_generated.add(param)
792+
if len(bucket.grads_generated) == len(bucket_fragments):
787793
bucket.status = self.GradientStatus.FULLY_FILLED
788794

789795
# Launch reductions if enough buckets are ready
790-
if len(self._grads_generated) == self._num_grads:
791-
self._force_bucket_grad_sync()
792-
else:
793-
filled_buckets = []
794-
for bucket_id, bucket in sorted(self._grads_buckets.items()):
795-
if ignore_last_bucket and bucket_id == len(self.state['buckets'])-1:
796-
continue
797-
if bucket.status == self.GradientStatus.FULLY_FILLED:
798-
filled_buckets.append(bucket)
799-
pipeline_size = _round_to_multiple(
800-
len(filled_buckets),
801-
self.pipeline_size,
802-
)
803-
if pipeline_size > 0:
804-
self._start_bucket_grad_sync(filled_buckets[:pipeline_size])
796+
filled_buckets = []
797+
for bucket_id, bucket in sorted(self._grads_buckets.items()):
798+
if ignore_last_bucket and bucket_id == len(self.state['buckets'])-1:
799+
continue
800+
if bucket.status == self.GradientStatus.FULLY_FILLED:
801+
filled_buckets.append(bucket)
802+
pipeline_size = _round_to_multiple(
803+
len(filled_buckets),
804+
self.pipeline_size,
805+
round_up=False,
806+
)
807+
if pipeline_size > 0:
808+
self._start_bucket_grad_sync(filled_buckets[:pipeline_size])
805809

806810
def _start_bucket_grad_sync(self, buckets):
807811
"""Synchronize gradient buckets
@@ -827,6 +831,7 @@ def _start_bucket_grad_sync(self, buckets):
827831
# Reduce-scatter over distributed process group
828832
for i, bucket in enumerate(buckets):
829833
bucket.status = self.GradientStatus.SYNCING
834+
bucket.grads_generated.clear()
830835
bucket.sync_wait()
831836
if self.distributed_size == 1:
832837
bucket.sync_grads_shard = bucket.grads_bucket
@@ -932,7 +937,7 @@ def grad_sync(self):
932937
)
933938
self._force_bucket_grad_sync()
934939

935-
def _local_grad_norm(self, parameters=[], norm_type=2.0):
940+
def _local_grad_norm(self, parameters=None, norm_type=2.0):
936941
"""Local contribution to parameter gradient norm
937942
938943
Returns square of 2-norm. Other norms are not yet supported.
@@ -948,7 +953,7 @@ def _local_grad_norm(self, parameters=[], norm_type=2.0):
948953
# Make sure that gradients have been reduced
949954
self.grad_sync()
950955

951-
if not parameters or len(parameters) == self._num_grads:
956+
if parameters is None or len(parameters) == self._num_grads:
952957
# Compute norm of all local gradients
953958
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=self.device)
954959
grad_norm_sq = multi_tensor_applier(
@@ -982,7 +987,7 @@ def _local_grad_norm(self, parameters=[], norm_type=2.0):
982987
grad_norm_sq = grad_norm_sq.view([])
983988
return grad_norm_sq
984989

985-
def grad_norm(self, parameters=[], norm_type=2.0, force=False):
990+
def grad_norm(self, parameters=None, norm_type=2.0, force=False):
986991
"""Gradient norm of parameters in optimizer
987992
988993
The norm is computed over all gradients together, as if they
@@ -1016,7 +1021,7 @@ def grad_norm(self, parameters=[], norm_type=2.0, force=False):
10161021
self._grad_norm = grad_norm_sq.sqrt()
10171022
return self._grad_norm.detach()
10181023

1019-
def clip_grad_norm(self, max_norm, parameters=[], norm_type=2.0):
1024+
def clip_grad_norm(self, max_norm, parameters=None, norm_type=2.0):
10201025
"""Clips gradient norm of parameters in optimizer
10211026
10221027
The norm is computed over all gradients together, as if they
@@ -1330,28 +1335,25 @@ def state_dict(self, gather_on_root=True):
13301335
# Split data into chunks and gather on root rank
13311336
# Note: Assuming we are using the NCCL backend, communication
13321337
# must happen on the GPU. We split the data into fixed-size
1333-
# chunks so that the GPU memory usage is limited to
1334-
# (chunk_size * distributed_size) bytes.
1338+
# chunks to limit GPU memory usage.
13351339
# TODO: Avoid chunking with direct communication between CPUs
13361340
main_stream = torch.cuda.current_stream()
13371341
for stream in self._pipeline_streams:
13381342
stream.wait_stream(main_stream)
13391343
for stream_id, offset in enumerate(range(0, max_state_size, chunk_size)):
13401344
stream_id %= self.pipeline_size
1341-
1342-
# Buffers for chunk
1343-
if self.distributed_rank == 0:
1344-
gathered_chunks = [
1345-
gathered_chunks_buffers[stream_id][i*chunk_size:(i+1)*chunk_size]
1346-
for i in range(self.distributed_size)
1347-
]
1348-
else:
1349-
chunk = chunk_buffers[stream_id]
1350-
1351-
# Perform communication on parallel stream
13521345
stream = self._pipeline_streams[stream_id]
13531346
with torch.cuda.stream(stream):
13541347

1348+
# Buffers for chunk
1349+
if self.distributed_rank == 0:
1350+
gathered_chunks = [
1351+
gathered_chunks_buffers[stream_id][i*chunk_size:(i+1)*chunk_size]
1352+
for i in range(self.distributed_size)
1353+
]
1354+
else:
1355+
chunk = chunk_buffers[stream_id]
1356+
13551357
# Copy to GPU
13561358
if self.distributed_rank != 0 and offset < local_state_size:
13571359
local_chunk_size = min(chunk_size, local_state_size-offset)
@@ -1366,24 +1368,30 @@ def state_dict(self, gather_on_root=True):
13661368
)
13671369

13681370
# Gather on root
1369-
if self.distributed_rank == 0:
1370-
if self._gather_no_copy:
1371-
no_copy_kwarg = { 'no_copy': True }
1371+
# Note: Call in main stream to avoid memory pool
1372+
# overheads from internal memory allocations in
1373+
# gather.
1374+
main_stream.wait_stream(stream)
1375+
with torch.cuda.stream(main_stream):
1376+
if self.distributed_rank == 0:
1377+
if self._gather_no_copy:
1378+
no_copy_kwarg = { 'no_copy': True }
1379+
else:
1380+
no_copy_kwarg = {}
1381+
torch.distributed.gather(
1382+
gathered_chunks[0],
1383+
gathered_chunks,
1384+
dst=self._process_group_ranks[0],
1385+
group=self.process_group,
1386+
**no_copy_kwarg,
1387+
)
13721388
else:
1373-
no_copy_kwarg = {}
1374-
torch.distributed.gather(
1375-
gathered_chunks[0],
1376-
gathered_chunks,
1377-
dst=self._process_group_ranks[0],
1378-
group=self.process_group,
1379-
**no_copy_kwarg,
1380-
)
1381-
else:
1382-
torch.distributed.gather(
1383-
chunk,
1384-
dst=self._process_group_ranks[0],
1385-
group=self.process_group,
1386-
)
1389+
torch.distributed.gather(
1390+
chunk,
1391+
dst=self._process_group_ranks[0],
1392+
group=self.process_group,
1393+
)
1394+
stream.wait_stream(main_stream)
13871395

13881396
# Copy back to CPU
13891397
if self.distributed_rank == 0:

apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from contextlib import contextmanager
1+
import contextlib
22
from typing import List, Union, Optional
33

44
import torch
@@ -20,14 +20,6 @@
2020
_logger = get_transformer_logger(__name__)
2121

2222

23-
@contextmanager
24-
def placeholder_handler():
25-
try:
26-
yield
27-
finally:
28-
pass
29-
30-
3123
def forward_backward_no_pipelining(
3224
forward_step_func: FwdStepFunc,
3325
batch: Batch,
@@ -59,7 +51,7 @@ def forward_backward_no_pipelining(
5951
disable_autocast: Turn off `enabled` flag of `torch.cuda.amp.autocast` if :obj:`True`.
6052
Should be used when your forward and loss computation is in the autocast context to
6153
avoid unnecesarily nest autocast context.
62-
custom_sync_context_handler:
54+
custom_sync_context_handler: Context manager to disable asynchronous gradient reductions.
6355
**kwargs: Added to handle `tensor_shape` which has no effect on this function.
6456
6557
Returns:
@@ -77,7 +69,7 @@ def forward_backward_no_pipelining(
7769
elif isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
7870
context_handler = model.no_sync
7971
else:
80-
context_handler = placeholder_handler
72+
context_handler = contextlib.nullcontext
8173

8274
losses_reduced = []
8375
input_tensor, output_tensor_grad = None, None

0 commit comments

Comments
 (0)