Enhance Distributed Fused Adam#1794
Conversation
| int chunk_idx = tl.block_to_chunk[blockIdx.x]; | ||
| int n = tl.sizes[tensor_loc]; | ||
|
|
||
| const float grad_scale = *grad_scale_ptr; |
There was a problem hiding this comment.
nit: might better to check whether or not grad_scale_ptr is nullptr
There was a problem hiding this comment.
Added some assertion at the beginning.
| local_p_out[ii] = static_cast<PARAM_OUT_T>(local_p[ii]); | ||
| } | ||
|
|
||
| // Store |
There was a problem hiding this comment.
would there be any appetite to use gradients after step? if so, it'd be necessary to store unscaled gradients as well.
There was a problem hiding this comment.
We haven't encounter such cases yet, after optimizer stepping a new iteration starts, where the gradients will be zero out first. This is the same behavior as other optimizer kernels in the repo, so I think we can leave it as it is until there're cases we need to store unscaled gradients in future.
| self.state["step"] += 1 if not self.capturable else \ | ||
| (self._dummy_overflow_buf != 1).to(torch.int) |
There was a problem hiding this comment.
Q: where would we decrement this value when self.capturable is True and invalid grads are found?
There was a problem hiding this comment.
As you know self.state["step"] is to track how many steps the optimizer has advanced, it is used for bias correction in the CUDA kernel. When invalid grads are found, self._dummy_overflow_buf is 1, then it's self.state["step"] += 0, otherwise it's self.state["step"] += 1. We don't need to decrement it in such form.
There was a problem hiding this comment.
uh, I misread it, thank you for correcting me.
another question: would this be really host-device sync free?
There was a problem hiding this comment.
Yes, in this case distributed fused adam is sync-free.
There was a problem hiding this comment.
As we've discussed, the grad clipping behavior is incorrect because plain PyTorch optimizers don't handle grad scaling gracefully:
# Plain PyTorch
torch.nn.clip_grad_norm_(model.parameters())
scaler.step(optim) # Clipped grads are scaled
# Distributed optimizer
optim.clip_grad_norm()
scaler.step(optim) # Clipped grads are scaledI'd prefer if distopt were as close as possible to a drop-in optimizer replacement, so I don't think the current behavior should be changed.
Supporting correct grad clipping is important though. I propose the following API:
# Plain PyTorch
scaler.unscale_(optim)
torch.nn.clip_grad_norm_(model.parameters())
scaler.step(optim) # Grads are not scaled
# Distributed optimizer
optim.unscale_grads(grad_scaler=scaler)
optim.clip_grad_norm()
scaler.step(optim) # Grads are not scaled|
@timmoon10, I like the idea of drop-in optimizer replacement. Right now, distributed fused adam sets To support the idea you mentioned, I need |
|
I've implemented my proposed API at timmoon10@0fa8e3a, although I haven't been able to test yet. NeMo GPT avoided these issues because it implemented a custom
|
|
But there's also the issue when Overlapping reduce-scatter with bprop is quite important to the performance, so I think it is necessary. |
|
I see, we need |
… copy after all-gather.
0304852 to
c466e70
Compare
This PR enhances distributed fused adam by:
@timmoon10 @crcrpar Please help review, thanks.