Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 9c50239

Browse files
crcrparminitu
andauthored
[contrib] Improve apex.contrib.focal_loss for B200 (#1888)
Co-authored-by: Jaemin Choi <[email protected]>
1 parent 312acb4 commit 9c50239

1 file changed

Lines changed: 26 additions & 20 deletions

File tree

apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
#include <ATen/AccumulateType.h>
33
#include <ATen/cuda/CUDAContext.h>
44

5+
// Use 128-bit vectorization
6+
typedef uint4 vector_t;
57

6-
#define ASSERT_UINT4_ALIGNED(PTR) \
7-
TORCH_INTERNAL_ASSERT(is_aligned<uint4>(PTR), "Tensor " #PTR " is not uint4 aligned")
8+
#define ASSERT_ALIGNED(DTYPE, PTR) \
9+
TORCH_INTERNAL_ASSERT(is_aligned<DTYPE>(PTR), "Tensor " #PTR " is not " #DTYPE " aligned")
810

911
template <class T> bool is_aligned(const void *ptr) noexcept {
1012
auto iptr = reinterpret_cast<std::uintptr_t>(ptr);
@@ -38,7 +40,7 @@ __global__ void focal_loss_forward_cuda_kernel(
3840
pp_norm = one - smoothing_factor + smoothing_factor / K;
3941
}
4042

41-
uint4 p_vec, grad_vec;
43+
vector_t p_vec, grad_vec;
4244

4345
// Accumulate loss on each thread
4446
for (int64_t i = (blockIdx.x * blockDim.x + threadIdx.x) * ILP;
@@ -48,15 +50,15 @@ __global__ void focal_loss_forward_cuda_kernel(
4850
int64_t base_yid = i % num_classes;
4951

5052
int64_t pos_idx = idy * num_classes + y;
51-
p_vec = *(uint4 *)&cls_output[i];
53+
p_vec = *(vector_t *)&cls_output[i]; // Vectorized load
5254

5355
// Skip ignored matches
5456
if (y == -2) {
5557
#pragma unroll
5658
for (int j = 0; j < ILP; j++) {
5759
*((scalar_t *)(&grad_vec) + j) = 0;
5860
}
59-
*(uint4 *)&partial_grad[i] = grad_vec;
61+
*(vector_t *)&partial_grad[i] = grad_vec;
6062
continue;
6163
}
6264

@@ -108,8 +110,8 @@ __global__ void focal_loss_forward_cuda_kernel(
108110
*((scalar_t *)(&grad_vec) + j) = static_cast<scalar_t>(grad);
109111
}
110112

111-
// This can't ensure to generate stg.128 and may be two stg.64.
112-
*(uint4 *)&partial_grad[i] = grad_vec;
113+
// This may generate two vectorized stores instead of one
114+
*(vector_t *)&partial_grad[i] = grad_vec;
113115
}
114116
loss_shm[threadIdx.x] = loss_acc;
115117

@@ -144,15 +146,15 @@ __global__ void focal_loss_backward_cuda_kernel(
144146
if (idx >= numel)
145147
return;
146148

147-
uint4 grad_vec;
148-
grad_vec = *(uint4 *)&partial_grad[idx];
149+
vector_t grad_vec;
150+
grad_vec = *(vector_t *)&partial_grad[idx];
149151
#pragma unroll(ILP)
150152
for (int i = 0; i < ILP; i++) {
151153
auto grad = static_cast<accscalar_t>(*((scalar_t *)(&grad_vec) + i));
152154
grad *= normalizer;
153155
*((scalar_t *)(&grad_vec) + i) = static_cast<scalar_t>(grad);
154156
}
155-
*(uint4 *)&partial_grad[idx] = grad_vec;
157+
*(vector_t *)&partial_grad[idx] = grad_vec;
156158
}
157159

158160
std::vector<at::Tensor> focal_loss_forward_cuda(
@@ -175,10 +177,10 @@ std::vector<at::Tensor> focal_loss_forward_cuda(
175177
"Mis-matched shape between class output and label.");
176178

177179
// Checks required for better performance
178-
const int ILP = sizeof(uint4) / cls_output.element_size();
179-
ASSERT_UINT4_ALIGNED(cls_output.data_ptr());
180+
const int ILP = sizeof(vector_t) / cls_output.element_size();
181+
ASSERT_ALIGNED(vector_t, cls_output.data_ptr());
180182
TORCH_INTERNAL_ASSERT(cls_output.size(-1) % ILP == 0,
181-
"Pad number of classes first to take advantage of 128 bit load.");
183+
"Pad number of classes first to take advantage of vectorized load.");
182184
TORCH_INTERNAL_ASSERT(num_real_classes >= ILP, "Too few classes.");
183185

184186
int64_t num_classes = cls_output.size(-1);
@@ -190,12 +192,16 @@ std::vector<at::Tensor> focal_loss_forward_cuda(
190192
// helps with focal loss.
191193
at::Tensor partial_grad = at::empty_like(cls_output);
192194

193-
// The grid contains 2 CTA per SM, each CTA loop on input with stride till the
194-
// last item.
195+
// Set the number of CTAs per SM according to the compute capability.
196+
// Each CTA loops on input with stride till the last item.
195197
cudaDeviceProp props;
196198
cudaGetDeviceProperties(&props, at::cuda::current_device());
199+
int cta_per_sm = 2;
200+
if (props.major >= 10) {
201+
cta_per_sm = 8;
202+
}
197203
dim3 block(512);
198-
dim3 grid(2 * props.multiProcessorCount);
204+
dim3 grid(cta_per_sm * props.multiProcessorCount);
199205

200206
// Specialize on label smoothing or not to reduce redundant operations
201207
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -205,7 +211,7 @@ std::vector<at::Tensor> focal_loss_forward_cuda(
205211
using accscalar_t = at::acc_type<scalar_t, true>;
206212
using labelscalar_t = int64_t;
207213
using outscalar_t = float;
208-
const int ILP = sizeof(uint4) / sizeof(scalar_t);
214+
const int ILP = sizeof(vector_t) / sizeof(scalar_t);
209215
focal_loss_forward_cuda_kernel<false, ILP, scalar_t, labelscalar_t,
210216
accscalar_t, outscalar_t>
211217
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
@@ -223,7 +229,7 @@ std::vector<at::Tensor> focal_loss_forward_cuda(
223229
using accscalar_t = at::acc_type<scalar_t, true>;
224230
using labelscalar_t = int64_t;
225231
using outscalar_t = float;
226-
const int ILP = sizeof(uint4) / sizeof(scalar_t);
232+
const int ILP = sizeof(vector_t) / sizeof(scalar_t);
227233
focal_loss_forward_cuda_kernel<true, ILP, scalar_t, labelscalar_t,
228234
accscalar_t, outscalar_t>
229235
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
@@ -245,7 +251,7 @@ at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output,
245251
const at::Tensor &partial_grad,
246252
const at::Tensor &num_positives_sum) {
247253
// Each thread process ILP elements
248-
const int ILP = sizeof(uint4) / partial_grad.element_size();
254+
const int ILP = sizeof(vector_t) / partial_grad.element_size();
249255
dim3 block(512);
250256
dim3 grid((partial_grad.numel() + block.x * ILP - 1) / (block.x * ILP));
251257

@@ -254,7 +260,7 @@ at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output,
254260
partial_grad.scalar_type(), "focal_loss_bprop", [&] {
255261
using accscalar_t = at::acc_type<scalar_t, true>;
256262
using outscalar_t = float;
257-
const int ILP = sizeof(uint4) / sizeof(scalar_t);
263+
const int ILP = sizeof(vector_t) / sizeof(scalar_t);
258264
focal_loss_backward_cuda_kernel<ILP, scalar_t, accscalar_t, outscalar_t>
259265
<<<grid, block, 0, stream>>>(partial_grad.data_ptr<scalar_t>(),
260266
grad_output.data_ptr<outscalar_t>(),

0 commit comments

Comments
 (0)