Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 03c9d80

Browse files
authored
Decouple optimizer state and grad dtypes in distributed Adam optimizer (#1575)
* Decouple distopt dtypes for grads and optim state * Automatically detect grad dtype for Transformer layer wgrad fusion * Review suggestions from @crcrpar
1 parent 0c8400a commit 03c9d80

5 files changed

Lines changed: 97 additions & 106 deletions

File tree

apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ struct DistAdamFunctor
200200
* to store FP32 main params. Instead, store 16-bit param remainder
201201
* and combine with BF16 param to reconstruct the FP32 main param.
202202
*/
203+
template <typename GRAD_T>
203204
struct DistAdamWithParamRemaindersFunctor
204205
{
205206
__device__ __forceinline__ void operator()(
@@ -230,7 +231,7 @@ struct DistAdamWithParamRemaindersFunctor
230231
m += chunk_idx*chunk_size;
231232
float* v = (float *)tl.addresses[3][tensor_loc];
232233
v += chunk_idx*chunk_size;
233-
float* g = (float *)tl.addresses[4][tensor_loc];
234+
const GRAD_T* g = (GRAD_T *)tl.addresses[4][tensor_loc];
234235
g += chunk_idx*chunk_size;
235236
int16_t* p_out = (int16_t *)tl.addresses[5][tensor_loc];
236237
p_out += chunk_idx*chunk_size;
@@ -256,7 +257,7 @@ struct DistAdamWithParamRemaindersFunctor
256257
int16_t local_p_rem[ILP];
257258
float local_m[ILP];
258259
float local_v[ILP];
259-
float local_g[ILP];
260+
GRAD_T local_g[ILP];
260261

261262
// Load
262263
if (aligned) {
@@ -294,7 +295,7 @@ struct DistAdamWithParamRemaindersFunctor
294295
}
295296

296297
// Local compute
297-
using LocalFunctor = DistAdamFunctor<float, float, void>;
298+
using LocalFunctor = DistAdamFunctor<float, GRAD_T, void>;
298299
LocalFunctor::local_step(
299300
reinterpret_cast<float *>(local_p), local_m, local_v, local_g, grad_scale,
300301
beta1, beta2, beta1_correction, beta2_correction,
@@ -349,12 +350,9 @@ void multi_tensor_fused_adam_cuda(
349350
// Expect p_in, m, v, g, p_out
350351
size_t tl_sz = tensor_lists.size();
351352
TORCH_CHECK(tl_sz == 5, "expected tensor lists of size 5");
352-
353-
// Assume p_in and g have same type
354-
auto p_in_type = tensor_lists[0][0].scalar_type();
355-
auto g_type = tensor_lists[3][0].scalar_type();
356-
auto p_out_type = tensor_lists[4][0].scalar_type();
357-
TORCH_CHECK(p_in_type == g_type, "expected main params and grads to have same type");
353+
const auto p_in_type = tensor_lists[0][0].scalar_type();
354+
const auto g_type = tensor_lists[3][0].scalar_type();
355+
const auto p_out_type = tensor_lists[4][0].scalar_type();
358356

359357
float beta1_correction = 1.0f, beta2_correction = 1.0f;
360358
if (bias_correction == 1) {
@@ -363,23 +361,24 @@ void multi_tensor_fused_adam_cuda(
363361
}
364362

365363
DISPATCH_FLOAT_HALF_AND_BFLOAT(p_in_type, 0, "dist_adam_cuda_kernel",
366-
DISPATCH_FLOAT_HALF_AND_BFLOAT(p_out_type, 1, "dist_adam_cuda_kernel",
367-
multi_tensor_apply<5>(
368-
BLOCK_SIZE,
369-
chunk_size,
370-
noop_flag,
371-
tensor_lists,
372-
DistAdamFunctor<scalar_t_0, scalar_t_0, scalar_t_1>(),
373-
grad_scale.DATA_PTR<float>(),
374-
beta1,
375-
beta2,
376-
beta1_correction,
377-
beta2_correction,
378-
eps,
379-
lr,
380-
(adamMode_t) mode,
381-
weight_decay);
382-
));
364+
DISPATCH_FLOAT_HALF_AND_BFLOAT(g_type, 1, "dist_adam_cuda_kernel",
365+
DISPATCH_FLOAT_HALF_AND_BFLOAT(p_out_type, 2, "dist_adam_cuda_kernel",
366+
multi_tensor_apply<5>(
367+
BLOCK_SIZE,
368+
chunk_size,
369+
noop_flag,
370+
tensor_lists,
371+
DistAdamFunctor<scalar_t_0, scalar_t_1, scalar_t_2>(),
372+
grad_scale.data_ptr<float>(),
373+
beta1,
374+
beta2,
375+
beta1_correction,
376+
beta2_correction,
377+
eps,
378+
lr,
379+
(adamMode_t) mode,
380+
weight_decay);
381+
)));
383382
C10_CUDA_CHECK(cudaGetLastError());
384383
}
385384

@@ -402,27 +401,30 @@ void multi_tensor_fused_adam_with_param_remainders_cuda(
402401
// Expect p_in, p_rem, m, v, g, p_out
403402
size_t tl_sz = tensor_lists.size();
404403
TORCH_CHECK(tl_sz == 6, "expected tensor lists of size 6");
404+
const auto g_type = tensor_lists[4][0].scalar_type();
405405

406406
float beta1_correction = 1.0f, beta2_correction = 1.0f;
407407
if (bias_correction == 1) {
408408
beta1_correction = 1 - std::pow(beta1, step);
409409
beta2_correction = 1 - std::pow(beta2, step);
410410
}
411411

412-
multi_tensor_apply<6>(
413-
BLOCK_SIZE,
414-
chunk_size,
415-
noop_flag,
416-
tensor_lists,
417-
DistAdamWithParamRemaindersFunctor(),
418-
grad_scale.DATA_PTR<float>(),
419-
beta1,
420-
beta2,
421-
beta1_correction,
422-
beta2_correction,
423-
eps,
424-
lr,
425-
(adamMode_t) mode,
426-
weight_decay);
412+
DISPATCH_FLOAT_HALF_AND_BFLOAT(g_type, 0, "dist_adam_with_param_remainders_cuda_kernel",
413+
multi_tensor_apply<6>(
414+
BLOCK_SIZE,
415+
chunk_size,
416+
noop_flag,
417+
tensor_lists,
418+
DistAdamWithParamRemaindersFunctor<scalar_t_0>(),
419+
grad_scale.data_ptr<float>(),
420+
beta1,
421+
beta2,
422+
beta1_correction,
423+
beta2_correction,
424+
eps,
425+
lr,
426+
(adamMode_t) mode,
427+
weight_decay);
428+
);
427429
C10_CUDA_CHECK(cudaGetLastError());
428430
}

apex/contrib/optimizers/distributed_fused_adam.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -444,11 +444,6 @@ def __init__(self,
444444
f'grad_sync_dtype={grad_sync_dtype}, '
445445
f'param_sync_dtype={param_sync_dtype}))'
446446
)
447-
if grad_sync_dtype != dtype:
448-
raise RuntimeError(
449-
'DistributedFusedAdam requires dtype to match grad dtype '
450-
f'(dtype={dtype}, grad_sync_dtype={grad_sync_dtype})'
451-
)
452447
self.dtype = dtype
453448
self.grad_sync_dtype = grad_sync_dtype
454449
self.param_sync_dtype = param_sync_dtype
@@ -488,13 +483,7 @@ def __init__(self,
488483
f'distributed process group size = {self.distributed_size}, '
489484
f'redundant process group size = {self.redundant_size})'
490485
)
491-
try:
492-
self._process_group_ranks = [
493-
get_global_rank(self.process_group, local_rank)
494-
for local_rank in range(self.distributed_size)
495-
]
496-
except:
497-
self._process_group_ranks = list(range(self.distributed_size))
486+
self.process_group_root = get_global_rank(self.process_group, 0)
498487

499488
# Use average reduction for grad sync
500489
self.average_grad_sync = average_grad_sync
@@ -515,14 +504,12 @@ def __init__(self,
515504
'with store_params=True and store_param_remainders=True'
516505
)
517506
if (self.dtype != torch.float32
518-
or self.grad_sync_dtype != torch.float32
519507
or self.param_sync_dtype != torch.bfloat16):
520508
raise RuntimeError(
521509
'DistributedFusedAdam requires '
522510
'BF16 params and FP32 optimizer state '
523511
'when storing parameter remainders '
524512
f'(dtype={self.dtype}, '
525-
f'grad_sync_dtype={self.grad_sync_dtype}, '
526513
f'param_sync_dtype={self.param_sync_dtype}))'
527514
)
528515
self.store_params = store_params
@@ -565,7 +552,7 @@ def __init__(self,
565552

566553
# Scale by factor before optimizer step. Used for grad
567554
# clipping and gradient scaler.
568-
self._grad_scale = torch.full([], 1.0, dtype=self.dtype, device=self.device)
555+
self._grad_scale = torch.full([], 1.0, dtype=torch.float32, device=self.device)
569556
# Norm of parameter gradients. Used for gradient clipping and
570557
# gradient scaler.
571558
self._grad_norm = None
@@ -603,7 +590,7 @@ def _broadcast_params(self):
603590
sync_requests.append(
604591
torch.distributed.broadcast(
605592
param,
606-
src=self._process_group_ranks[0],
593+
src=self.process_group_root,
607594
group=process_group,
608595
async_op=True,
609596
)
@@ -829,7 +816,7 @@ def _init_grad_buffer(self):
829816
buffer_size = 0
830817
self._grad_buffer = torch.zeros(
831818
[buffer_size],
832-
dtype=self.dtype,
819+
dtype=self.grad_sync_dtype,
833820
device=self.device,
834821
)
835822

@@ -1089,7 +1076,7 @@ def zero_grad(self, set_to_none=False):
10891076
param.grad.zero_()
10901077

10911078
# Reset other state
1092-
self._grad_scale = torch.full([], 1.0, dtype=self.dtype, device=self.device)
1079+
self._grad_scale = torch.full([], 1.0, dtype=torch.float32, device=self.device)
10931080
self._grad_norm = None
10941081
self._dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=self.device)
10951082

@@ -1934,7 +1921,6 @@ def state_dict(self, gather_on_root=True):
19341921
ranks on the root rank (default: True)
19351922
19361923
"""
1937-
### TODO Fix
19381924
state_dict = super().state_dict()
19391925
if not gather_on_root:
19401926
return state_dict
@@ -2036,14 +2022,14 @@ def state_dict(self, gather_on_root=True):
20362022
torch.distributed.gather(
20372023
gathered_chunks[0],
20382024
gathered_chunks,
2039-
dst=self._process_group_ranks[0],
2025+
dst=self.process_group_root,
20402026
group=self.process_group,
20412027
**no_copy_kwarg,
20422028
)
20432029
else:
20442030
torch.distributed.gather(
20452031
chunk,
2046-
dst=self._process_group_ranks[0],
2032+
dst=self.process_group_root,
20472033
group=self.process_group,
20482034
)
20492035
stream.wait_stream(main_stream)

apex/contrib/test/optimizers/test_dist_adam.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def make_models(
3333
adam_w_mode=True,
3434
model_dtype=torch.float32,
3535
optim_dtype=None,
36+
grad_sync_dtype=None,
3637
param_sync_dtype=None,
3738
device='cuda',
3839
overlap_communication=True,
@@ -79,6 +80,7 @@ def make_models(
7980
overlap_param_sync=overlap_communication,
8081
bucket_cap_mb=71/(4*1024*1024),
8182
dtype=optim_dtype,
83+
grad_sync_dtype=grad_sync_dtype,
8284
param_sync_dtype=param_sync_dtype,
8385
contiguous_param_buffer=contiguous_buffers,
8486
contiguous_grad_buffer=contiguous_buffers,
@@ -117,6 +119,7 @@ def test_matches_pytorch(
117119
use_nosync=True,
118120
model_dtype=torch.float32,
119121
optim_dtype=None,
122+
grad_sync_dtype=None,
120123
param_sync_dtype=None,
121124
device='cuda',
122125
contiguous_buffers=False,
@@ -133,6 +136,7 @@ def test_matches_pytorch(
133136
adam_w_mode=adam_w_mode,
134137
model_dtype=model_dtype,
135138
optim_dtype=optim_dtype,
139+
grad_sync_dtype=grad_sync_dtype,
136140
param_sync_dtype=param_sync_dtype,
137141
device=device,
138142
overlap_communication=overlap_communication,
@@ -239,6 +243,16 @@ def test_matches_pytorch_fp16_params(self):
239243
store_params=True,
240244
)
241245

246+
def test_matches_pytorch_bf16_grads(self):
247+
self.test_matches_pytorch(
248+
rtol=5e-2,
249+
atol=1e-5,
250+
micro_batch_steps=1,
251+
model_dtype=torch.float32,
252+
optim_dtype=torch.float32,
253+
grad_sync_dtype=torch.bfloat16,
254+
)
255+
242256
def test_matches_pytorch_bf16_param_remainders(self):
243257
self.test_matches_pytorch(
244258
rtol=5e-2,

0 commit comments

Comments
 (0)