Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Enhance Distributed Fused Adam#1794

Merged
Aidyn-A merged 10 commits into
NVIDIA:24.04.01-develfrom
alpha0422:wkong/dist-adam
Apr 28, 2024
Merged

Enhance Distributed Fused Adam#1794
Aidyn-A merged 10 commits into
NVIDIA:24.04.01-develfrom
alpha0422:wkong/dist-adam

Conversation

@alpha0422
Copy link
Copy Markdown
Contributor

This PR enhances distributed fused adam by:

  1. Support NHWC layout (required by some Conv related models, e.g. Diffusion models);
  2. Fix the gradient clipping bug;
  3. Support CUDA graph;

@timmoon10 @crcrpar Please help review, thanks.

int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

const float grad_scale = *grad_scale_ptr;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: might better to check whether or not grad_scale_ptr is nullptr

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some assertion at the beginning.

local_p_out[ii] = static_cast<PARAM_OUT_T>(local_p[ii]);
}

// Store
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would there be any appetite to use gradients after step? if so, it'd be necessary to store unscaled gradients as well.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We haven't encounter such cases yet, after optimizer stepping a new iteration starts, where the gradients will be zero out first. This is the same behavior as other optimizer kernels in the repo, so I think we can leave it as it is until there're cases we need to store unscaled gradients in future.

Comment on lines +2395 to +2500
self.state["step"] += 1 if not self.capturable else \
(self._dummy_overflow_buf != 1).to(torch.int)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: where would we decrement this value when self.capturable is True and invalid grads are found?

Copy link
Copy Markdown
Contributor Author

@alpha0422 alpha0422 Apr 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you know self.state["step"] is to track how many steps the optimizer has advanced, it is used for bias correction in the CUDA kernel. When invalid grads are found, self._dummy_overflow_buf is 1, then it's self.state["step"] += 0, otherwise it's self.state["step"] += 1. We don't need to decrement it in such form.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uh, I misread it, thank you for correcting me.

another question: would this be really host-device sync free?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in this case distributed fused adam is sync-free.

Copy link
Copy Markdown
Member

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we've discussed, the grad clipping behavior is incorrect because plain PyTorch optimizers don't handle grad scaling gracefully:

# Plain PyTorch
torch.nn.clip_grad_norm_(model.parameters())
scaler.step(optim)  # Clipped grads are scaled

# Distributed optimizer
optim.clip_grad_norm()
scaler.step(optim)  # Clipped grads are scaled

I'd prefer if distopt were as close as possible to a drop-in optimizer replacement, so I don't think the current behavior should be changed.

Supporting correct grad clipping is important though. I propose the following API:

# Plain PyTorch
scaler.unscale_(optim)
torch.nn.clip_grad_norm_(model.parameters())
scaler.step(optim)  # Grads are not scaled

# Distributed optimizer
optim.unscale_grads(grad_scaler=scaler)
optim.clip_grad_norm()
scaler.step(optim)  # Grads are not scaled

@alpha0422
Copy link
Copy Markdown
Contributor Author

@timmoon10, I like the idea of drop-in optimizer replacement. Right now, distributed fused adam sets _step_supports_amp_scaling, so scaler.unscale_(optim) or optim.unscale_grads(grad_scaler=scaler) won't be called from PyTorch or PyTorch Lightning, because the assumption of _step_supports_amp_scaling is the gradient unscaling will be done in the optimizer step function, thus gradient clipping need to be delayed to the optimizer step function too.

To support the idea you mentioned, I need _step_supports_amp_scaling need to be removed, but then I think it will break other use cases, and it will decrease the performance because gradient unscaling is explicit and not fused with the step kernel.

@timmoon10
Copy link
Copy Markdown
Member

timmoon10 commented Apr 11, 2024

I've implemented my proposed API at timmoon10@0fa8e3a, although I haven't been able to test yet.

NeMo GPT avoided these issues because it implemented a custom GradScaler that called DistributedFusedAdam.unscale_grads within GradScaler.unscale_: https://github.com/NVIDIA/NeMo/blob/c5738263d8b4bedb0957374116d3e90746a51c37/nemo/collections/nlp/parts/nlp_overrides.py#L1235. See #1512 and NVIDIA-NeMo/NeMo#4900.

_step_supports_amp_scaling is needed because otherwise GradScaler.unscale_ would attempt to access the parameters' .grads, which have probably already been reduce-scattered and set to None. The only way I can see to avoid this is to disable overlapping grad reduce-scatters with backward compute.

@alpha0422
Copy link
Copy Markdown
Contributor Author

But there's also the issue when _step_supports_amp_scaling set, GradScaler.unscale_ will never be called from PyTorch or PyTorch Lightning. I saw you tried to unscale at here: nlp_overrides.py#L1202, but this function was never called, I confirmed with Stable Diffusion and LLM.

Overlapping reduce-scatter with bprop is quite important to the performance, so I think it is necessary.

@timmoon10
Copy link
Copy Markdown
Member

timmoon10 commented Apr 12, 2024

I see, we need _step_support_amp_scaling=False specifically when using nemo.collections.nlp.parts.nlp_overrides.GradScaler. However, _step_support_amp_scaling=True is needed for correct behavior with torch.amp.GradScaler. I think the cleanest solution is to set _step_support_amp_scaling=False in NeMo's distopt wrapper. That helps keep the NeMo-specific logic separate from the general PyTorch logic in Apex. Reverting the changes to the grad clipping logic (e.g. with timmoon10@0fa8e3a) is needed to preserve correct behavior with torch.amp.GradScaler.

@alpha0422 alpha0422 marked this pull request as draft April 12, 2024 08:03
@alpha0422 alpha0422 marked this pull request as ready for review April 25, 2024 11:26
@Aidyn-A Aidyn-A changed the base branch from master to 24.04.01-devel April 28, 2024 04:50
@Aidyn-A Aidyn-A merged commit 4138d31 into NVIDIA:24.04.01-devel Apr 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants