From 4138d31ff0acf4071d1dc001ccb7cd6e00800324 Mon Sep 17 00:00:00 2001 From: Wil Kong Date: Sun, 28 Apr 2024 12:50:32 +0800 Subject: [PATCH] Enhance Distributed Fused Adam (#1794) --- .../optimizers/multi_tensor_distopt_adam.cpp | 18 ++ .../multi_tensor_distopt_adam_kernel.cu | 214 ++++++++++++++++++ .../optimizers/distributed_fused_adam.py | 157 +++++++++++-- 3 files changed, 372 insertions(+), 17 deletions(-) diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp index 5be8b2840..2a2a878f0 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp @@ -14,6 +14,20 @@ void multi_tensor_fused_adam_cuda( int bias_correction, float weight_decay); +void multi_tensor_fused_adam_capturable_cuda( + int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor grad_scale, + at::Tensor lr, + float beta1, + float beta2, + float eps, + at::Tensor step, + int mode, + int bias_correction, + float weight_decay); + void multi_tensor_fused_adam_with_param_remainders_cuda( int chunk_size, at::Tensor noop_flag, @@ -33,6 +47,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &multi_tensor_fused_adam_cuda, "CUDA kernels for multi-tensor Adam, " "with param copy"); + m.def("multi_tensor_fused_adam_capturable", + &multi_tensor_fused_adam_capturable_cuda, + "CUDA kernels for multi-tensor Adam, " + "with param copy, capturable for CUDA graph"); m.def("multi_tensor_fused_adam_with_param_remainders", &multi_tensor_fused_adam_with_param_remainders_cuda, "CUDA kernel for multi-tensor Adam, " diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu index f769e57c2..54c0eb6ca 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu @@ -194,6 +194,175 @@ struct DistAdamFunctor } }; +/* Multi-tensor Adam with CUDA Graph Support + * + * Updates params in-place and outputs a copy with a desired datatype. + */ +template +struct DistAdamCapturableFunctor +{ + // Vectorized local compute + __device__ __forceinline__ static void local_step( + T p[ILP], + T m[ILP], + T v[ILP], + const GRAD_T g[ILP], + const float grad_scale, + const float beta1, + const float beta2, + const float beta1_correction, + const float beta2_correction, + const float eps, + const float lr, + adamMode_t mode, + const float weight_decay) { + if (mode == ADAM_MODE_0) { // L2 +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float scaled_grad = (g[ii] * grad_scale) + (weight_decay * p[ii]); + float next_m = lerp(beta1, scaled_grad, m[ii]); + float next_v = lerp(beta2, scaled_grad*scaled_grad, v[ii]); + float next_m_unbiased = next_m / beta1_correction; + float next_v_unbiased = next_v / beta2_correction; + float denom = sqrtf(next_v_unbiased) + eps; + float update = next_m_unbiased / denom; + m[ii] = next_m; + v[ii] = next_v; + p[ii] -= lr * update; + } + } else { // weight decay +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float scaled_grad = g[ii] * grad_scale; + float next_m = lerp(beta1, scaled_grad, m[ii]); + float next_v = lerp(beta2, scaled_grad*scaled_grad, v[ii]); + float next_m_unbiased = next_m / beta1_correction; + float next_v_unbiased = next_v / beta2_correction; + float denom = sqrtf(next_v_unbiased) + eps; + float update = (next_m_unbiased / denom) + (weight_decay * p[ii]); + m[ii] = next_m; + v[ii] = next_v; + p[ii] -= lr * update; + } + } + } + + __device__ __forceinline__ void operator()( + int chunk_size, + volatile int* noop_gmem, + TensorListMetadata<5>& tl, + const float* grad_scale_ptr, + const float beta1, + const float beta2, + const int* step, + const int bias_correction, + const float eps, + const float* lr, + adamMode_t mode, + const float weight_decay) const + { + assert(noop_gmem); + assert(grad_scale_ptr); + assert(step); + assert(lr); + + if(*noop_gmem == 1) + return; + + float beta1_correction = 1.0f, beta2_correction = 1.0f; + if (bias_correction == 1) { + beta1_correction = 1 - pow(beta1, *step); + beta2_correction = 1 - pow(beta2, *step); + } + + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + const float grad_scale = *grad_scale_ptr; + + T* p_in = (T *)tl.addresses[0][tensor_loc]; + p_in += chunk_idx*chunk_size; + T* m = (T *)tl.addresses[1][tensor_loc]; + m += chunk_idx*chunk_size; + T* v = (T *)tl.addresses[2][tensor_loc]; + v += chunk_idx*chunk_size; + const GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc]; + g += chunk_idx*chunk_size; + PARAM_OUT_T* p_out = (PARAM_OUT_T *)tl.addresses[4][tensor_loc]; + p_out += chunk_idx*chunk_size; + + n -= chunk_idx*chunk_size; + n = chunk_size < n ? chunk_size : n; + + const bool aligned = (n % ILP == 0 && + is_aligned(p_in) && + is_aligned(m) && + is_aligned(v) && + is_aligned(g) && + is_aligned(p_out)); + + for (int i_start = threadIdx.x*ILP; i_start < n; i_start += blockDim.x*ILP) { + T local_p[ILP]; + T local_m[ILP]; + T local_v[ILP]; + GRAD_T local_g[ILP]; + PARAM_OUT_T local_p_out[ILP]; + + // Load + if (aligned) { + load_store(local_p, p_in + i_start); + load_store(local_m, m + i_start); + load_store(local_v, v + i_start); + load_store(local_g, g + i_start); + } else { +#pragma unroll + for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { + if (i < n) { + local_p[ii] = p_in[i]; + local_m[ii] = m[i]; + local_v[ii] = v[i]; + local_g[ii] = g[i]; + } else { + local_p[ii] = 0; + local_m[ii] = 0; + local_v[ii] = 0; + local_g[ii] = 0; + } + } + } + + // Local compute + local_step( + local_p, local_m, local_v, local_g, grad_scale, + beta1, beta2, beta1_correction, beta2_correction, + eps, *lr, mode, weight_decay); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + local_p_out[ii] = static_cast(local_p[ii]); + } + + // Store + if (aligned) { + load_store(p_in + i_start, local_p); + load_store(m + i_start, local_m); + load_store(v + i_start, local_v); + load_store(p_out + i_start, local_p_out); + } else { +#pragma unroll + for (int ii = 0, i = i_start; ii < ILP; ii++, i++) { + if (i < n) { + p_in[i] = local_p[ii]; + m[i] = local_m[ii]; + v[i] = local_v[ii]; + p_out[i] = local_p_out[ii]; + } + } + } + } + } +}; + /* Functor for multi-tensor Adam with implicit main params * * If params are BF16 and optimizer state is FP32, it is not necessary @@ -382,6 +551,51 @@ void multi_tensor_fused_adam_cuda( C10_CUDA_CHECK(cudaGetLastError()); } +void multi_tensor_fused_adam_capturable_cuda( + int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, // p_in, m, v, g, p_out + at::Tensor grad_scale, + at::Tensor lr, + float beta1, + float beta2, + float eps, + at::Tensor step, + int mode, + int bias_correction, + float weight_decay) +{ + using namespace at; + + // Expect p_in, m, v, g, p_out + size_t tl_sz = tensor_lists.size(); + TORCH_CHECK(tl_sz == 5, "expected tensor lists of size 5"); + const auto p_in_type = tensor_lists[0][0].scalar_type(); + const auto g_type = tensor_lists[3][0].scalar_type(); + const auto p_out_type = tensor_lists[4][0].scalar_type(); + + DISPATCH_FLOAT_HALF_AND_BFLOAT(p_in_type, 0, "dist_adam_capturable_cuda_kernel", + DISPATCH_FLOAT_HALF_AND_BFLOAT(g_type, 1, "dist_adam_capturable_cuda_kernel", + DISPATCH_FLOAT_HALF_AND_BFLOAT(p_out_type, 2, "dist_adam_capturable_cuda_kernel", + multi_tensor_apply<5>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + DistAdamCapturableFunctor(), + grad_scale.data_ptr(), + beta1, + beta2, + step.data_ptr(), + bias_correction, + eps, + lr.data_ptr(), + (adamMode_t) mode, + weight_decay); + ))); + C10_CUDA_CHECK(cudaGetLastError()); +} + void multi_tensor_fused_adam_with_param_remainders_cuda( int chunk_size, at::Tensor noop_flag, diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index 1c5f98a83..e653fbe41 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -219,6 +219,7 @@ def _multi_tensor_copy( list(zip(*buffers)), ) else: + # Warning: dummy_overflow_buf was not set in such case for buf_in, buf_out in buffers: buf_out.copy_(buf_in) @@ -372,6 +373,8 @@ class DistributedFusedAdam(torch.optim.Optimizer): when IB SHARP is enabled in a one-rank-per-node communication group. This will help speedup the gemms overlapped with data- parallel communications. + capturable (bool, optional): whether to use the version of the + optimizer that can be used with CUDA Graphs. (default: False). .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 @@ -565,7 +568,24 @@ def __init__( store_param_remainders: bool = False, with_scaled_states: bool = False, nccl_ub: bool = False, + capturable: bool = False, ): + if (with_scaled_states or store_param_remainders) and capturable: + raise Exception(f"{self.__class__.__name__} with scaled states " + "or storing param remainders doesn't support CUDA graph yet.") + + if capturable and not _FOUND_DEPRECATED_FUSED_ADAM: + raise Exception(f"Capturable {self.__class__.__name__} relies on " + "multi_tensor_copy to set dummy_overflow_buf to indicate " + "whether there's gradient Inf/NaN, build APEX with " + "`--deprecated_fused_adam` is essential.") + + # If capturable for CUDA graph + self.capturable: bool = capturable + # If the optimizer is capturable then LR should be a tensor (on GPU) + lr: torch.Tensor | float = torch.tensor(lr, dtype=torch.float32) \ + if capturable else lr + defaults = dict( lr=lr, bias_correction=bias_correction, @@ -577,6 +597,7 @@ def __init__( # Adam options self.adam_w_mode: bool = adam_w_mode + self.amsgrad: bool = amsgrad if amsgrad: raise RuntimeError( "DistributedFusedAdam does not support the AMSGrad variant." @@ -709,6 +730,7 @@ def __init__( # Determine bucket sizes dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8 self.alignment: int = 128 // dtype_size + self.bucket_cap_mb: float = bucket_cap_mb bucket_size = 1024 * 1024 * bucket_cap_mb / dtype_size shard_size = int(bucket_size / self.distributed_size) shard_size = _round_to_multiple(shard_size, self.alignment, round_up=False) @@ -717,7 +739,8 @@ def __init__( # Optimizer state self.state["buckets"]: List[StateBucket] = [] - self.state["step"]: int = 0 + self.state["step"]: torch.Tensor | int = torch.tensor([0], dtype=torch.int, + device=self.device) if self.capturable else 0 # Gradient state self._grads_buckets: Dict[int, GradientBucket] = collections.defaultdict( @@ -790,6 +813,43 @@ def __init__( if self.overlap_param_sync: self._register_pre_forward_hooks() + # Move LR to device + if capturable: + for idx, group in enumerate(self.param_groups): + if len(group['params']) == 0: + continue + for item in ['lr']: + self.param_groups[idx][item] = group[item].to(device=self.device) + + # For better representation string + arg_names = inspect.getfullargspec(DistributedFusedAdam.__init__).args + arg_names.remove('self') + arg_names.remove('params') + for i, group in enumerate(self.param_groups): + for key in sorted(group.keys()): + if key in arg_names: + arg_names.remove(key) + self.args_dict = {name: getattr(self, name) for name in arg_names} + + def __repr__(self) -> str: + # Based on: https://github.com/pytorch/pytorch/blob/v2.3.0-rc12/torch/optim/optimizer.py#L315 + format_string = self.__class__.__name__ + ' (' + for i, group in enumerate(self.param_groups): + format_string += '\n' + format_string += f'Parameter Group {i}\n' + for key in sorted(group.keys()): + if key != 'params': + format_string += f' {key}: {group[key]}\n' + + for key, val in self.args_dict.items(): + if 'process_group' in key and val: + format_string += f'{key}: {hex(id(val))}, world size {val.size()}\n' + else: + format_string += f'{key}: {val}\n' + + format_string += ')' + return format_string + @torch.no_grad() def _broadcast_params(self) -> None: """Broadcast parameter values from root rank""" @@ -1082,6 +1142,8 @@ def init_param_buffer(self) -> None: f"Attempted to change a parameter with dtype={param.dtype} " f"into a buffer view with dtype={param_buffer_view.dtype}" ) + if param.is_contiguous(memory_format=torch.channels_last): + param = param.permute(0, 2, 3, 1) param_flat_views.append(param.detach().view(-1)) param_buffer_views.append(param_buffer_view) @@ -1094,7 +1156,15 @@ def init_param_buffer(self) -> None: # Make all params a view into the param buffer for param, buffer_view in zip(params, param_buffer_views): - param.data = buffer_view.view(param.size()) + # Preserve memory format for param here, i.e. NHWC tensors + # `param.data.set_()` failed to change storage. + # `param.set_()` invalidates bprop hook. + param.data = torch.as_strided( + buffer_view, + param.size(), + param.stride(), + storage_offset=buffer_view.storage_offset(), + ) def _init_grad_buffer(self) -> None: """Allocate contiguous buffer for grad buckets""" @@ -1465,6 +1535,13 @@ def make_bucket( continue bucket_id = fragment.bucket_id bucket = self.state["buckets"][bucket_id] + # If param is channels last, i.e. tensor with shape (N, C, H, W) + # and stride (HWC, 1, WC, C), then we will turn it into a tensor + # with shape (N, H, W, C) and stride (HWC, WC, C, 1). The purppose + # is to avoid failures when flattening the tensor (`.view(-1)`) + # and stepping the optimizer. + if param.is_contiguous(memory_format=torch.channels_last): + param = param.permute(0, 2, 3, 1) param_range = slice(*fragment.shard_param_range) shard_range = slice(*fragment.shard_range) model_param_fragment = param.detach().view(-1)[param_range] @@ -1540,11 +1617,9 @@ def zero_grad(self, set_to_none: bool = False) -> None: param.grad.zero_() # Reset other state - self._grad_scale = torch.full([], 1.0, dtype=torch.float32, device=self.device) + self._grad_scale.fill_(1.0) self._grad_norm = None - self._dummy_overflow_buf = torch.zeros( - [1], dtype=torch.int32, device=self.device - ) + self._dummy_overflow_buf.zero_() def _grad_copy(self, param: torch.nn.Parameter) -> None: """Copy parameter gradients to gradient buckets @@ -1605,7 +1680,11 @@ def _grad_copy(self, param: torch.nn.Parameter) -> None: # Copy param grad to bucket if param.grad is not None: - grad_in = param.grad.detach().view(-1)[grad_start:grad_end] + if param.grad.is_contiguous(memory_format=torch.channels_last): + grad_in = param.grad.permute(0, 2, 3, 1) + else: + grad_in = param.grad + grad_in = grad_in.detach().view(-1)[grad_start:grad_end] grad_out = bucket.grads_bucket[bucket_start:bucket_end] if grad_in.data_ptr() != grad_out.data_ptr(): grad_out.add_(grad_in) @@ -1684,6 +1763,13 @@ def _param_copy_fragments( # Corresponding positions in param bucket and param bucket = self._params_buckets[bucket_id] param = self.parameter(fragment) + + # Conv with NHWC layout, i.e. shape (N, C, H, W) and stride + # (HWC, 1, WC, C), can't `.view(-1)`. Here to turn it to + # tensor with shape (N, H, W, C) and stride (HWC, WC, C, 1). + if param.is_contiguous(memory_format=torch.channels_last): + param = param.permute(0, 2, 3, 1) + buffer_in = bucket.params_bucket[bucket_start:bucket_end] buffer_out = param.detach().view(-1)[param_start:param_end] @@ -1735,9 +1821,16 @@ def grad_buffer_view(self, param: torch.nn.Parameter) -> torch.Tensor: buffer_end = buffer_start + param.numel() # Construct view into grad buffer + # Preserve memory format for gradient here flat_buffer = self._grad_buffers[bucket.dtypes()] - flat_buffer = flat_buffer[buffer_start:buffer_end] - return flat_buffer.detach().view(param.size()) + grad = torch.empty(1, dtype=param.dtype, device=param.device) + grad.set_( + source=flat_buffer, + storage_offset=buffer_start, + size=param.size(), + stride=param.stride(), + ) + return grad def _force_bucket_grad_sync(self) -> None: """Ensure that all gradient buckets are synchronized""" @@ -2225,8 +2318,12 @@ def grad_norm( group=self.distributed_process_group, ) self._grad_norm = grad_norm_sq.sqrt() - grad_norm = self._grad_norm * self._grad_scale - return grad_norm.detach() + if hasattr(self, "_step_supports_amp_scaling") and self._step_supports_amp_scaling: + return self._grad_norm.detach() + else: + # Notice that update of self._grad_scale changes grad_norm. + grad_norm = self._grad_norm * self._grad_scale + return grad_norm.detach() def clip_grad_norm( self, @@ -2254,9 +2351,15 @@ def clip_grad_norm( """ assert max_norm > 0 total_norm = self.grad_norm(parameters=parameters, norm_type=norm_type) - clip_coef = max_norm / (total_norm + 1e-6) - clip_coef_clamped = torch.clamp(clip_coef, max=1.0) - self._grad_scale *= clip_coef_clamped + if hasattr(self, "_step_supports_amp_scaling") and self._step_supports_amp_scaling: + # Gradients haven't been unscaled yet, thus total_norm is + # `scaler._scale` times larger and clip_coef_clamped is wrong. + # Thus the clipping must be deferred to optimizer step. + self._max_norm = max_norm + else: + clip_coef = max_norm / (total_norm + 1e-6) + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + self._grad_scale *= clip_coef_clamped return total_norm def unscale_grads(self, inv_scale: torch.Tensor, *args): @@ -2268,6 +2371,12 @@ def unscale_grads(self, inv_scale: torch.Tensor, *args): inv_scale (torch.Tensor): factor to multiply gradients """ + # When `_step_supports_amp_scaling` set, gradient unscaling is performed + # in the optimizer step function. + if hasattr(self, "_step_supports_amp_scaling") and self._step_supports_amp_scaling: + raise Exception("`_step_supports_amp_scaling` was set, gradient " + "unscaling is expected to be applied in the optimizer step function.") + self._grad_scale *= inv_scale.view([]) return {self.device: torch.zeros(1, dtype=torch.float32, device=self.device)} @@ -2312,10 +2421,19 @@ def step( assert grad_scaler._scale is not None self._grad_scale /= grad_scaler._scale.view([]) grad_norm = self.grad_norm() + if hasattr(self, "_step_supports_amp_scaling") and self._step_supports_amp_scaling: + # Gradient norm was computed before gradient unscaling. + grad_norm = grad_norm / grad_scaler._scale.view([]) + clip_coef = self._max_norm / (grad_norm + 1e-6) + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + self._grad_scale *= clip_coef_clamped + found_inf = torch.logical_not(torch.isfinite(grad_norm)) scaler_state = grad_scaler._per_optimizer_states[id(self)] scaler_state["found_inf_per_device"] = {found_inf.device: found_inf.float()} - if found_inf.item(): + if self.capturable: + self._dummy_overflow_buf.copy_(found_inf) + elif found_inf.item(): return self._grad_scale = self._grad_scale.to(dtype=torch.float32, device=self.device) @@ -2378,7 +2496,8 @@ def step( ) # Apply optimizer step - self.state["step"] += 1 + self.state["step"] += 1 if not self.capturable else \ + (self._dummy_overflow_buf != 1).to(torch.int) overlap_first_bucket = ( self.distributed_size > 1 and self.overlap_param_sync @@ -2485,6 +2604,8 @@ def _local_step(self, bucket_ids: List[int]) -> None: shard_range = slice(shard_start, shard_end) if state_bucket.params_shard is None: param = self.parameter(fragment) + if param.is_contiguous(memory_format=torch.channels_last): + param = param.permute(0, 2, 3, 1) param_range = slice(*fragment.shard_param_range) param_fragment = param.detach().view(-1)[param_range] param_fragment = param_fragment.to( @@ -2510,11 +2631,13 @@ def _local_step(self, bucket_ids: List[int]) -> None: ) # Apply optimizer step to each param group + adam_func = distributed_adam_cuda.multi_tensor_fused_adam_capturable \ + if self.capturable else distributed_adam_cuda.multi_tensor_fused_adam for (group_id, _, _, _), group_buffers in buffers.items(): group = self.param_groups[group_id] beta1, beta2 = group["betas"] multi_tensor_applier( - distributed_adam_cuda.multi_tensor_fused_adam, + adam_func, self._dummy_overflow_buf, list(zip(*group_buffers)), self._grad_scale,