Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 28f8539

Browse files
crcrparalpha0422
andauthored
Add CUDA Focal Loss Implementation (#1337)
Take-over of #1097 * Add fast CUDA focal loss implementation. * Enable fast math for CUDA focal loss. * Correct typo. * replace deprecated macros * Add fast CUDA focal loss implementation. * Enable fast math for CUDA focal loss. * Correct typo. * replace deprecated macros * TORCH_CUDA_CHECK -> AT_CUDA_CHECK The former is defined in torch/csrc/profiler/cuda.cpp so it's not available usually. The latter however is defined in ATen/cuda/Exceptions.h as an alias of C10_CUDA_CHECK. * add test * clean up * guard for torchvision Co-authored-by: Wil Kong <[email protected]>
1 parent feae385 commit 28f8539

6 files changed

Lines changed: 492 additions & 0 deletions

File tree

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#include <torch/torch.h>
2+
3+
#include <vector>
4+
#include <cstdint>
5+
6+
// CUDA forward declarations
7+
8+
std::vector<at::Tensor> focal_loss_forward_cuda(
9+
const at::Tensor &cls_output,
10+
const at::Tensor &cls_targets_at_level,
11+
const at::Tensor &num_positives_sum,
12+
const int64_t num_real_classes,
13+
const float alpha,
14+
const float gamma,
15+
const float smoothing_factor);
16+
17+
at::Tensor focal_loss_backward_cuda(
18+
const at::Tensor &grad_output,
19+
const at::Tensor &partial_grad,
20+
const at::Tensor &num_positives_sum);
21+
22+
// C++ interface
23+
24+
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
25+
#define CHECK_CONTIGUOUS(x) \
26+
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
27+
#define CHECK_INPUT(x) \
28+
CHECK_CUDA(x); \
29+
CHECK_CONTIGUOUS(x)
30+
31+
std::vector<at::Tensor> focal_loss_forward(
32+
const at::Tensor &cls_output,
33+
const at::Tensor &cls_targets_at_level,
34+
const at::Tensor &num_positives_sum,
35+
const int64_t num_real_classes,
36+
const float alpha,
37+
const float gamma,
38+
const float smoothing_factor
39+
) {
40+
CHECK_INPUT(cls_output);
41+
CHECK_INPUT(cls_targets_at_level);
42+
CHECK_INPUT(num_positives_sum);
43+
44+
return focal_loss_forward_cuda(
45+
cls_output,
46+
cls_targets_at_level,
47+
num_positives_sum,
48+
num_real_classes,
49+
alpha,
50+
gamma,
51+
smoothing_factor);
52+
}
53+
54+
at::Tensor focal_loss_backward(
55+
const at::Tensor &grad_output,
56+
const at::Tensor &partial_grad,
57+
const at::Tensor &num_positives_sum
58+
) {
59+
CHECK_INPUT(grad_output);
60+
CHECK_INPUT(partial_grad);
61+
62+
return focal_loss_backward_cuda(grad_output, partial_grad, num_positives_sum);
63+
}
64+
65+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
66+
m.def("forward", &focal_loss_forward,
67+
"Focal loss calculation forward (CUDA)");
68+
m.def("backward", &focal_loss_backward,
69+
"Focal loss calculation backward (CUDA)");
70+
}
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/AccumulateType.h>
3+
#include <ATen/cuda/CUDAContext.h>
4+
5+
6+
#define ASSERT_UINT4_ALIGNED(PTR) \
7+
TORCH_INTERNAL_ASSERT(is_aligned<uint4>(PTR), "Tensor " #PTR " is not uint4 aligned")
8+
9+
template <class T> bool is_aligned(const void *ptr) noexcept {
10+
auto iptr = reinterpret_cast<std::uintptr_t>(ptr);
11+
return !(iptr % alignof(T));
12+
}
13+
14+
template <bool SMOOTHING, int ILP, typename scalar_t, typename labelscalar_t,
15+
typename accscalar_t, typename outscalar_t>
16+
__global__ void focal_loss_forward_cuda_kernel(
17+
outscalar_t *loss, scalar_t *partial_grad,
18+
const scalar_t *__restrict__ cls_output,
19+
const labelscalar_t *__restrict__ cls_targets_at_level,
20+
const float *__restrict__ num_positives_sum, const int64_t num_examples,
21+
const int64_t num_classes, const int64_t num_real_classes,
22+
const float alpha, const float gamma, const float smoothing_factor) {
23+
extern __shared__ unsigned char shm[];
24+
accscalar_t *loss_shm = reinterpret_cast<accscalar_t *>(shm);
25+
loss_shm[threadIdx.x] = 0;
26+
accscalar_t loss_acc = 0;
27+
28+
accscalar_t one = accscalar_t(1.0);
29+
accscalar_t K = accscalar_t(2.0);
30+
accscalar_t normalizer = one / static_cast<accscalar_t>(num_positives_sum[0]);
31+
accscalar_t nn_norm, np_norm, pn_norm, pp_norm;
32+
33+
// *_norm is used for label smoothing only
34+
if (SMOOTHING) {
35+
nn_norm = one - smoothing_factor / K;
36+
np_norm = smoothing_factor / K;
37+
pn_norm = smoothing_factor - smoothing_factor / K;
38+
pp_norm = one - smoothing_factor + smoothing_factor / K;
39+
}
40+
41+
uint4 p_vec, grad_vec;
42+
43+
// Accumulate loss on each thread
44+
for (int64_t i = (blockIdx.x * blockDim.x + threadIdx.x) * ILP;
45+
i < num_examples * num_classes; i += gridDim.x * blockDim.x * ILP) {
46+
int64_t idy = i / num_classes;
47+
labelscalar_t y = cls_targets_at_level[idy];
48+
int64_t base_yid = i % num_classes;
49+
50+
int64_t pos_idx = idy * num_classes + y;
51+
p_vec = *(uint4 *)&cls_output[i];
52+
53+
// Skip ignored matches
54+
if (y == -2) {
55+
#pragma unroll
56+
for (int j = 0; j < ILP; j++) {
57+
*((scalar_t *)(&grad_vec) + j) = 0;
58+
}
59+
*(uint4 *)&partial_grad[i] = grad_vec;
60+
continue;
61+
}
62+
63+
#pragma unroll
64+
for (int j = 0; j < ILP; j++) {
65+
// Skip the pad classes
66+
if (base_yid + j >= num_real_classes) {
67+
*((scalar_t *)(&grad_vec) + j) = 0;
68+
continue;
69+
}
70+
71+
accscalar_t p = static_cast<accscalar_t>(*((scalar_t *)(&p_vec) + j));
72+
accscalar_t exp_np = ::exp(-p);
73+
accscalar_t exp_pp = ::exp(p);
74+
accscalar_t sigma = one / (one + exp_np);
75+
accscalar_t logee = (p >= 0) ? exp_np : exp_pp;
76+
accscalar_t addee = (p >= 0) ? 0 : -p;
77+
accscalar_t off_a = addee + ::log(one + logee);
78+
79+
// Negative matches
80+
accscalar_t base = SMOOTHING ? nn_norm * p : p;
81+
accscalar_t off_b = (SMOOTHING ? np_norm : 0) - sigma;
82+
accscalar_t coeff_f1 = one - alpha;
83+
accscalar_t coeff_f2 = sigma;
84+
accscalar_t coeff_b1 = gamma;
85+
accscalar_t coeff_b2 = one - sigma;
86+
87+
// Positive matches
88+
if (y >= 0 && (i + j == pos_idx)) {
89+
base = SMOOTHING ? pn_norm * p : 0;
90+
off_b = (SMOOTHING ? pp_norm : one) - sigma;
91+
coeff_f1 = alpha;
92+
coeff_f2 = one - sigma;
93+
coeff_b1 = -gamma;
94+
coeff_b2 = sigma;
95+
}
96+
97+
accscalar_t coeff_f = coeff_f1 * ::pow(coeff_f2, gamma);
98+
accscalar_t coeff_b = coeff_b1 * coeff_b2;
99+
100+
accscalar_t loss_t = coeff_f * (base + off_a);
101+
accscalar_t grad = coeff_f * (coeff_b * (base + off_a) - off_b);
102+
103+
// Delay the normalize of partial gradient by num_positives_sum to back
104+
// propagation because scalar_t reduces precision. Focal loss is very
105+
// sensitive to the small gradient. No worry on overflow here since
106+
// gradient has relative smaller range than input.
107+
loss_acc += loss_t;
108+
*((scalar_t *)(&grad_vec) + j) = static_cast<scalar_t>(grad);
109+
}
110+
111+
// This can't ensure to generate stg.128 and may be two stg.64.
112+
*(uint4 *)&partial_grad[i] = grad_vec;
113+
}
114+
loss_shm[threadIdx.x] = loss_acc;
115+
116+
// Intra-CTA reduction
117+
__syncthreads();
118+
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
119+
if (threadIdx.x < s) {
120+
loss_shm[threadIdx.x] += loss_shm[threadIdx.x + s];
121+
}
122+
__syncthreads();
123+
}
124+
125+
// Inter-CTA reduction
126+
if (threadIdx.x == 0) {
127+
loss_acc = loss_shm[0] * normalizer;
128+
atomicAdd(loss, loss_acc);
129+
}
130+
}
131+
132+
template <int ILP, typename scalar_t, typename accscalar_t,
133+
typename outscalar_t>
134+
__global__ void focal_loss_backward_cuda_kernel(
135+
scalar_t *partial_grad, const outscalar_t *__restrict__ grad_output,
136+
const float *__restrict__ num_positives_sum, const uint64_t numel) {
137+
int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * ILP;
138+
139+
accscalar_t normalizer = static_cast<accscalar_t>(grad_output[0]) /
140+
static_cast<accscalar_t>(num_positives_sum[0]);
141+
142+
// The input is enforced to pad to use vector load, thus there's no need to
143+
// check whether the last element of ILP can out of bound.
144+
if (idx >= numel)
145+
return;
146+
147+
uint4 grad_vec;
148+
grad_vec = *(uint4 *)&partial_grad[idx];
149+
#pragma unroll(ILP)
150+
for (int i = 0; i < ILP; i++) {
151+
auto grad = static_cast<accscalar_t>(*((scalar_t *)(&grad_vec) + i));
152+
grad *= normalizer;
153+
*((scalar_t *)(&grad_vec) + i) = static_cast<scalar_t>(grad);
154+
}
155+
*(uint4 *)&partial_grad[idx] = grad_vec;
156+
}
157+
158+
std::vector<at::Tensor> focal_loss_forward_cuda(
159+
const at::Tensor &cls_output, const at::Tensor &cls_targets_at_level,
160+
const at::Tensor &num_positives_sum, const int64_t num_real_classes,
161+
const float alpha, const float gamma, const float smoothing_factor) {
162+
// Checks required for correctness
163+
TORCH_INTERNAL_ASSERT(cls_output.size(-1) >= num_real_classes,
164+
"Incorrect number of real classes.");
165+
TORCH_INTERNAL_ASSERT(cls_targets_at_level.scalar_type() == at::kLong,
166+
"Invalid label type.");
167+
TORCH_INTERNAL_ASSERT(
168+
(num_positives_sum.numel() == 1) &&
169+
(num_positives_sum.scalar_type() == at::kFloat),
170+
"Expect num_positives_sum to be a float32 tensor with only one element.");
171+
TORCH_INTERNAL_ASSERT(cls_output.dim() == cls_targets_at_level.dim() + 1,
172+
"Mis-matched dimensions between class output and label.");
173+
for (int64_t i = 0; i < cls_targets_at_level.dim(); i++)
174+
TORCH_INTERNAL_ASSERT(cls_output.size(i) == cls_targets_at_level.size(i),
175+
"Mis-matched shape between class output and label.");
176+
177+
// Checks required for better performance
178+
const int ILP = sizeof(uint4) / cls_output.element_size();
179+
ASSERT_UINT4_ALIGNED(cls_output.data_ptr());
180+
TORCH_INTERNAL_ASSERT(cls_output.size(-1) % ILP == 0,
181+
"Pad number of classes first to take advantage of 128 bit load.");
182+
TORCH_INTERNAL_ASSERT(num_real_classes >= ILP, "Too few classes.");
183+
184+
int64_t num_classes = cls_output.size(-1);
185+
int64_t num_examples = cls_output.numel() / num_classes;
186+
at::Tensor loss = at::zeros({}, cls_output.options().dtype(at::kFloat));
187+
188+
// Compute the incompelete gradient during fprop since most of the heavy
189+
// functions of bprop are the same as fprop, thus trade memory for compute
190+
// helps with focal loss.
191+
at::Tensor partial_grad = at::empty_like(cls_output);
192+
193+
// The grid contains 2 CTA per SM, each CTA loop on input with stride till the
194+
// last item.
195+
cudaDeviceProp props;
196+
cudaGetDeviceProperties(&props, at::cuda::current_device());
197+
dim3 block(512);
198+
dim3 grid(2 * props.multiProcessorCount);
199+
200+
// Specialize on label smoothing or not to reduce redundant operations
201+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
202+
if (smoothing_factor == 0.0f) {
203+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
204+
cls_output.scalar_type(), "focal_loss_fprop", [&] {
205+
using accscalar_t = at::acc_type<scalar_t, true>;
206+
using labelscalar_t = int64_t;
207+
using outscalar_t = float;
208+
const int ILP = sizeof(uint4) / sizeof(scalar_t);
209+
focal_loss_forward_cuda_kernel<false, ILP, scalar_t, labelscalar_t,
210+
accscalar_t, outscalar_t>
211+
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
212+
loss.data_ptr<outscalar_t>(),
213+
partial_grad.data_ptr<scalar_t>(),
214+
cls_output.data_ptr<scalar_t>(),
215+
cls_targets_at_level.data_ptr<labelscalar_t>(),
216+
num_positives_sum.data_ptr<float>(), num_examples,
217+
num_classes, num_real_classes, alpha, gamma,
218+
smoothing_factor);
219+
});
220+
} else {
221+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
222+
cls_output.scalar_type(), "focal_loss_fprop", [&] {
223+
using accscalar_t = at::acc_type<scalar_t, true>;
224+
using labelscalar_t = int64_t;
225+
using outscalar_t = float;
226+
const int ILP = sizeof(uint4) / sizeof(scalar_t);
227+
focal_loss_forward_cuda_kernel<true, ILP, scalar_t, labelscalar_t,
228+
accscalar_t, outscalar_t>
229+
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
230+
loss.data_ptr<outscalar_t>(),
231+
partial_grad.data_ptr<scalar_t>(),
232+
cls_output.data_ptr<scalar_t>(),
233+
cls_targets_at_level.data_ptr<labelscalar_t>(),
234+
num_positives_sum.data_ptr<float>(), num_examples,
235+
num_classes, num_real_classes, alpha, gamma,
236+
smoothing_factor);
237+
});
238+
}
239+
240+
AT_CUDA_CHECK(cudaGetLastError());
241+
return {loss, partial_grad};
242+
}
243+
244+
at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output,
245+
const at::Tensor &partial_grad,
246+
const at::Tensor &num_positives_sum) {
247+
// Each thread process ILP elements
248+
const int ILP = sizeof(uint4) / partial_grad.element_size();
249+
dim3 block(512);
250+
dim3 grid((partial_grad.numel() + block.x * ILP - 1) / (block.x * ILP));
251+
252+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
253+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
254+
partial_grad.scalar_type(), "focal_loss_bprop", [&] {
255+
using accscalar_t = at::acc_type<scalar_t, true>;
256+
using outscalar_t = float;
257+
const int ILP = sizeof(uint4) / sizeof(scalar_t);
258+
focal_loss_backward_cuda_kernel<ILP, scalar_t, accscalar_t, outscalar_t>
259+
<<<grid, block, 0, stream>>>(partial_grad.data_ptr<scalar_t>(),
260+
grad_output.data_ptr<outscalar_t>(),
261+
num_positives_sum.data_ptr<float>(),
262+
partial_grad.numel());
263+
});
264+
265+
AT_CUDA_CHECK(cudaGetLastError());
266+
return partial_grad;
267+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
try:
2+
import torch
3+
import focal_loss_cuda
4+
from .focal_loss import focal_loss
5+
del torch
6+
del focal_loss_cuda
7+
del focal_loss
8+
except ImportError as err:
9+
print("apex was installed without --focal_loss flag, apex.contrib.focal_loss is not available")

0 commit comments

Comments
 (0)