|
| 1 | +#include <ATen/ATen.h> |
| 2 | +#include <ATen/AccumulateType.h> |
| 3 | +#include <ATen/cuda/CUDAContext.h> |
| 4 | + |
| 5 | + |
| 6 | +#define ASSERT_UINT4_ALIGNED(PTR) \ |
| 7 | + TORCH_INTERNAL_ASSERT(is_aligned<uint4>(PTR), "Tensor " #PTR " is not uint4 aligned") |
| 8 | + |
| 9 | +template <class T> bool is_aligned(const void *ptr) noexcept { |
| 10 | + auto iptr = reinterpret_cast<std::uintptr_t>(ptr); |
| 11 | + return !(iptr % alignof(T)); |
| 12 | +} |
| 13 | + |
| 14 | +template <bool SMOOTHING, int ILP, typename scalar_t, typename labelscalar_t, |
| 15 | + typename accscalar_t, typename outscalar_t> |
| 16 | +__global__ void focal_loss_forward_cuda_kernel( |
| 17 | + outscalar_t *loss, scalar_t *partial_grad, |
| 18 | + const scalar_t *__restrict__ cls_output, |
| 19 | + const labelscalar_t *__restrict__ cls_targets_at_level, |
| 20 | + const float *__restrict__ num_positives_sum, const int64_t num_examples, |
| 21 | + const int64_t num_classes, const int64_t num_real_classes, |
| 22 | + const float alpha, const float gamma, const float smoothing_factor) { |
| 23 | + extern __shared__ unsigned char shm[]; |
| 24 | + accscalar_t *loss_shm = reinterpret_cast<accscalar_t *>(shm); |
| 25 | + loss_shm[threadIdx.x] = 0; |
| 26 | + accscalar_t loss_acc = 0; |
| 27 | + |
| 28 | + accscalar_t one = accscalar_t(1.0); |
| 29 | + accscalar_t K = accscalar_t(2.0); |
| 30 | + accscalar_t normalizer = one / static_cast<accscalar_t>(num_positives_sum[0]); |
| 31 | + accscalar_t nn_norm, np_norm, pn_norm, pp_norm; |
| 32 | + |
| 33 | + // *_norm is used for label smoothing only |
| 34 | + if (SMOOTHING) { |
| 35 | + nn_norm = one - smoothing_factor / K; |
| 36 | + np_norm = smoothing_factor / K; |
| 37 | + pn_norm = smoothing_factor - smoothing_factor / K; |
| 38 | + pp_norm = one - smoothing_factor + smoothing_factor / K; |
| 39 | + } |
| 40 | + |
| 41 | + uint4 p_vec, grad_vec; |
| 42 | + |
| 43 | + // Accumulate loss on each thread |
| 44 | + for (int64_t i = (blockIdx.x * blockDim.x + threadIdx.x) * ILP; |
| 45 | + i < num_examples * num_classes; i += gridDim.x * blockDim.x * ILP) { |
| 46 | + int64_t idy = i / num_classes; |
| 47 | + labelscalar_t y = cls_targets_at_level[idy]; |
| 48 | + int64_t base_yid = i % num_classes; |
| 49 | + |
| 50 | + int64_t pos_idx = idy * num_classes + y; |
| 51 | + p_vec = *(uint4 *)&cls_output[i]; |
| 52 | + |
| 53 | + // Skip ignored matches |
| 54 | + if (y == -2) { |
| 55 | +#pragma unroll |
| 56 | + for (int j = 0; j < ILP; j++) { |
| 57 | + *((scalar_t *)(&grad_vec) + j) = 0; |
| 58 | + } |
| 59 | + *(uint4 *)&partial_grad[i] = grad_vec; |
| 60 | + continue; |
| 61 | + } |
| 62 | + |
| 63 | +#pragma unroll |
| 64 | + for (int j = 0; j < ILP; j++) { |
| 65 | + // Skip the pad classes |
| 66 | + if (base_yid + j >= num_real_classes) { |
| 67 | + *((scalar_t *)(&grad_vec) + j) = 0; |
| 68 | + continue; |
| 69 | + } |
| 70 | + |
| 71 | + accscalar_t p = static_cast<accscalar_t>(*((scalar_t *)(&p_vec) + j)); |
| 72 | + accscalar_t exp_np = ::exp(-p); |
| 73 | + accscalar_t exp_pp = ::exp(p); |
| 74 | + accscalar_t sigma = one / (one + exp_np); |
| 75 | + accscalar_t logee = (p >= 0) ? exp_np : exp_pp; |
| 76 | + accscalar_t addee = (p >= 0) ? 0 : -p; |
| 77 | + accscalar_t off_a = addee + ::log(one + logee); |
| 78 | + |
| 79 | + // Negative matches |
| 80 | + accscalar_t base = SMOOTHING ? nn_norm * p : p; |
| 81 | + accscalar_t off_b = (SMOOTHING ? np_norm : 0) - sigma; |
| 82 | + accscalar_t coeff_f1 = one - alpha; |
| 83 | + accscalar_t coeff_f2 = sigma; |
| 84 | + accscalar_t coeff_b1 = gamma; |
| 85 | + accscalar_t coeff_b2 = one - sigma; |
| 86 | + |
| 87 | + // Positive matches |
| 88 | + if (y >= 0 && (i + j == pos_idx)) { |
| 89 | + base = SMOOTHING ? pn_norm * p : 0; |
| 90 | + off_b = (SMOOTHING ? pp_norm : one) - sigma; |
| 91 | + coeff_f1 = alpha; |
| 92 | + coeff_f2 = one - sigma; |
| 93 | + coeff_b1 = -gamma; |
| 94 | + coeff_b2 = sigma; |
| 95 | + } |
| 96 | + |
| 97 | + accscalar_t coeff_f = coeff_f1 * ::pow(coeff_f2, gamma); |
| 98 | + accscalar_t coeff_b = coeff_b1 * coeff_b2; |
| 99 | + |
| 100 | + accscalar_t loss_t = coeff_f * (base + off_a); |
| 101 | + accscalar_t grad = coeff_f * (coeff_b * (base + off_a) - off_b); |
| 102 | + |
| 103 | + // Delay the normalize of partial gradient by num_positives_sum to back |
| 104 | + // propagation because scalar_t reduces precision. Focal loss is very |
| 105 | + // sensitive to the small gradient. No worry on overflow here since |
| 106 | + // gradient has relative smaller range than input. |
| 107 | + loss_acc += loss_t; |
| 108 | + *((scalar_t *)(&grad_vec) + j) = static_cast<scalar_t>(grad); |
| 109 | + } |
| 110 | + |
| 111 | + // This can't ensure to generate stg.128 and may be two stg.64. |
| 112 | + *(uint4 *)&partial_grad[i] = grad_vec; |
| 113 | + } |
| 114 | + loss_shm[threadIdx.x] = loss_acc; |
| 115 | + |
| 116 | + // Intra-CTA reduction |
| 117 | + __syncthreads(); |
| 118 | + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { |
| 119 | + if (threadIdx.x < s) { |
| 120 | + loss_shm[threadIdx.x] += loss_shm[threadIdx.x + s]; |
| 121 | + } |
| 122 | + __syncthreads(); |
| 123 | + } |
| 124 | + |
| 125 | + // Inter-CTA reduction |
| 126 | + if (threadIdx.x == 0) { |
| 127 | + loss_acc = loss_shm[0] * normalizer; |
| 128 | + atomicAdd(loss, loss_acc); |
| 129 | + } |
| 130 | +} |
| 131 | + |
| 132 | +template <int ILP, typename scalar_t, typename accscalar_t, |
| 133 | + typename outscalar_t> |
| 134 | +__global__ void focal_loss_backward_cuda_kernel( |
| 135 | + scalar_t *partial_grad, const outscalar_t *__restrict__ grad_output, |
| 136 | + const float *__restrict__ num_positives_sum, const uint64_t numel) { |
| 137 | + int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * ILP; |
| 138 | + |
| 139 | + accscalar_t normalizer = static_cast<accscalar_t>(grad_output[0]) / |
| 140 | + static_cast<accscalar_t>(num_positives_sum[0]); |
| 141 | + |
| 142 | + // The input is enforced to pad to use vector load, thus there's no need to |
| 143 | + // check whether the last element of ILP can out of bound. |
| 144 | + if (idx >= numel) |
| 145 | + return; |
| 146 | + |
| 147 | + uint4 grad_vec; |
| 148 | + grad_vec = *(uint4 *)&partial_grad[idx]; |
| 149 | +#pragma unroll(ILP) |
| 150 | + for (int i = 0; i < ILP; i++) { |
| 151 | + auto grad = static_cast<accscalar_t>(*((scalar_t *)(&grad_vec) + i)); |
| 152 | + grad *= normalizer; |
| 153 | + *((scalar_t *)(&grad_vec) + i) = static_cast<scalar_t>(grad); |
| 154 | + } |
| 155 | + *(uint4 *)&partial_grad[idx] = grad_vec; |
| 156 | +} |
| 157 | + |
| 158 | +std::vector<at::Tensor> focal_loss_forward_cuda( |
| 159 | + const at::Tensor &cls_output, const at::Tensor &cls_targets_at_level, |
| 160 | + const at::Tensor &num_positives_sum, const int64_t num_real_classes, |
| 161 | + const float alpha, const float gamma, const float smoothing_factor) { |
| 162 | + // Checks required for correctness |
| 163 | + TORCH_INTERNAL_ASSERT(cls_output.size(-1) >= num_real_classes, |
| 164 | + "Incorrect number of real classes."); |
| 165 | + TORCH_INTERNAL_ASSERT(cls_targets_at_level.scalar_type() == at::kLong, |
| 166 | + "Invalid label type."); |
| 167 | + TORCH_INTERNAL_ASSERT( |
| 168 | + (num_positives_sum.numel() == 1) && |
| 169 | + (num_positives_sum.scalar_type() == at::kFloat), |
| 170 | + "Expect num_positives_sum to be a float32 tensor with only one element."); |
| 171 | + TORCH_INTERNAL_ASSERT(cls_output.dim() == cls_targets_at_level.dim() + 1, |
| 172 | + "Mis-matched dimensions between class output and label."); |
| 173 | + for (int64_t i = 0; i < cls_targets_at_level.dim(); i++) |
| 174 | + TORCH_INTERNAL_ASSERT(cls_output.size(i) == cls_targets_at_level.size(i), |
| 175 | + "Mis-matched shape between class output and label."); |
| 176 | + |
| 177 | + // Checks required for better performance |
| 178 | + const int ILP = sizeof(uint4) / cls_output.element_size(); |
| 179 | + ASSERT_UINT4_ALIGNED(cls_output.data_ptr()); |
| 180 | + TORCH_INTERNAL_ASSERT(cls_output.size(-1) % ILP == 0, |
| 181 | + "Pad number of classes first to take advantage of 128 bit load."); |
| 182 | + TORCH_INTERNAL_ASSERT(num_real_classes >= ILP, "Too few classes."); |
| 183 | + |
| 184 | + int64_t num_classes = cls_output.size(-1); |
| 185 | + int64_t num_examples = cls_output.numel() / num_classes; |
| 186 | + at::Tensor loss = at::zeros({}, cls_output.options().dtype(at::kFloat)); |
| 187 | + |
| 188 | + // Compute the incompelete gradient during fprop since most of the heavy |
| 189 | + // functions of bprop are the same as fprop, thus trade memory for compute |
| 190 | + // helps with focal loss. |
| 191 | + at::Tensor partial_grad = at::empty_like(cls_output); |
| 192 | + |
| 193 | + // The grid contains 2 CTA per SM, each CTA loop on input with stride till the |
| 194 | + // last item. |
| 195 | + cudaDeviceProp props; |
| 196 | + cudaGetDeviceProperties(&props, at::cuda::current_device()); |
| 197 | + dim3 block(512); |
| 198 | + dim3 grid(2 * props.multiProcessorCount); |
| 199 | + |
| 200 | + // Specialize on label smoothing or not to reduce redundant operations |
| 201 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 202 | + if (smoothing_factor == 0.0f) { |
| 203 | + AT_DISPATCH_FLOATING_TYPES_AND_HALF( |
| 204 | + cls_output.scalar_type(), "focal_loss_fprop", [&] { |
| 205 | + using accscalar_t = at::acc_type<scalar_t, true>; |
| 206 | + using labelscalar_t = int64_t; |
| 207 | + using outscalar_t = float; |
| 208 | + const int ILP = sizeof(uint4) / sizeof(scalar_t); |
| 209 | + focal_loss_forward_cuda_kernel<false, ILP, scalar_t, labelscalar_t, |
| 210 | + accscalar_t, outscalar_t> |
| 211 | + <<<grid, block, block.x * sizeof(accscalar_t), stream>>>( |
| 212 | + loss.data_ptr<outscalar_t>(), |
| 213 | + partial_grad.data_ptr<scalar_t>(), |
| 214 | + cls_output.data_ptr<scalar_t>(), |
| 215 | + cls_targets_at_level.data_ptr<labelscalar_t>(), |
| 216 | + num_positives_sum.data_ptr<float>(), num_examples, |
| 217 | + num_classes, num_real_classes, alpha, gamma, |
| 218 | + smoothing_factor); |
| 219 | + }); |
| 220 | + } else { |
| 221 | + AT_DISPATCH_FLOATING_TYPES_AND_HALF( |
| 222 | + cls_output.scalar_type(), "focal_loss_fprop", [&] { |
| 223 | + using accscalar_t = at::acc_type<scalar_t, true>; |
| 224 | + using labelscalar_t = int64_t; |
| 225 | + using outscalar_t = float; |
| 226 | + const int ILP = sizeof(uint4) / sizeof(scalar_t); |
| 227 | + focal_loss_forward_cuda_kernel<true, ILP, scalar_t, labelscalar_t, |
| 228 | + accscalar_t, outscalar_t> |
| 229 | + <<<grid, block, block.x * sizeof(accscalar_t), stream>>>( |
| 230 | + loss.data_ptr<outscalar_t>(), |
| 231 | + partial_grad.data_ptr<scalar_t>(), |
| 232 | + cls_output.data_ptr<scalar_t>(), |
| 233 | + cls_targets_at_level.data_ptr<labelscalar_t>(), |
| 234 | + num_positives_sum.data_ptr<float>(), num_examples, |
| 235 | + num_classes, num_real_classes, alpha, gamma, |
| 236 | + smoothing_factor); |
| 237 | + }); |
| 238 | + } |
| 239 | + |
| 240 | + AT_CUDA_CHECK(cudaGetLastError()); |
| 241 | + return {loss, partial_grad}; |
| 242 | +} |
| 243 | + |
| 244 | +at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output, |
| 245 | + const at::Tensor &partial_grad, |
| 246 | + const at::Tensor &num_positives_sum) { |
| 247 | + // Each thread process ILP elements |
| 248 | + const int ILP = sizeof(uint4) / partial_grad.element_size(); |
| 249 | + dim3 block(512); |
| 250 | + dim3 grid((partial_grad.numel() + block.x * ILP - 1) / (block.x * ILP)); |
| 251 | + |
| 252 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 253 | + AT_DISPATCH_FLOATING_TYPES_AND_HALF( |
| 254 | + partial_grad.scalar_type(), "focal_loss_bprop", [&] { |
| 255 | + using accscalar_t = at::acc_type<scalar_t, true>; |
| 256 | + using outscalar_t = float; |
| 257 | + const int ILP = sizeof(uint4) / sizeof(scalar_t); |
| 258 | + focal_loss_backward_cuda_kernel<ILP, scalar_t, accscalar_t, outscalar_t> |
| 259 | + <<<grid, block, 0, stream>>>(partial_grad.data_ptr<scalar_t>(), |
| 260 | + grad_output.data_ptr<outscalar_t>(), |
| 261 | + num_positives_sum.data_ptr<float>(), |
| 262 | + partial_grad.numel()); |
| 263 | + }); |
| 264 | + |
| 265 | + AT_CUDA_CHECK(cudaGetLastError()); |
| 266 | + return partial_grad; |
| 267 | +} |
0 commit comments