22#include < ATen/AccumulateType.h>
33#include < ATen/cuda/CUDAContext.h>
44
5+ // Use 128-bit vectorization
6+ typedef uint4 vector_t ;
57
6- #define ASSERT_UINT4_ALIGNED ( PTR ) \
7- TORCH_INTERNAL_ASSERT (is_aligned<uint4 >(PTR), "Tensor " #PTR " is not uint4 aligned" )
8+ #define ASSERT_ALIGNED ( DTYPE, PTR ) \
9+ TORCH_INTERNAL_ASSERT (is_aligned<DTYPE >(PTR), "Tensor " #PTR " is not " #DTYPE " aligned" )
810
911template <class T > bool is_aligned (const void *ptr) noexcept {
1012 auto iptr = reinterpret_cast <std::uintptr_t >(ptr);
@@ -38,7 +40,7 @@ __global__ void focal_loss_forward_cuda_kernel(
3840 pp_norm = one - smoothing_factor + smoothing_factor / K;
3941 }
4042
41- uint4 p_vec, grad_vec;
43+ vector_t p_vec, grad_vec;
4244
4345 // Accumulate loss on each thread
4446 for (int64_t i = (blockIdx .x * blockDim .x + threadIdx .x ) * ILP;
@@ -48,15 +50,15 @@ __global__ void focal_loss_forward_cuda_kernel(
4850 int64_t base_yid = i % num_classes;
4951
5052 int64_t pos_idx = idy * num_classes + y;
51- p_vec = *(uint4 *)&cls_output[i];
53+ p_vec = *(vector_t *)&cls_output[i]; // Vectorized load
5254
5355 // Skip ignored matches
5456 if (y == -2 ) {
5557#pragma unroll
5658 for (int j = 0 ; j < ILP; j++) {
5759 *((scalar_t *)(&grad_vec) + j) = 0 ;
5860 }
59- *(uint4 *)&partial_grad[i] = grad_vec;
61+ *(vector_t *)&partial_grad[i] = grad_vec;
6062 continue ;
6163 }
6264
@@ -108,8 +110,8 @@ __global__ void focal_loss_forward_cuda_kernel(
108110 *((scalar_t *)(&grad_vec) + j) = static_cast <scalar_t >(grad);
109111 }
110112
111- // This can't ensure to generate stg.128 and may be two stg.64.
112- *(uint4 *)&partial_grad[i] = grad_vec;
113+ // The compiler may split this into two narrower stores instead of one full-width vector store
114+ *(vector_t *)&partial_grad[i] = grad_vec;
113115 }
114116 loss_shm[threadIdx .x ] = loss_acc;
115117
@@ -144,15 +146,15 @@ __global__ void focal_loss_backward_cuda_kernel(
144146 if (idx >= numel)
145147 return ;
146148
147- uint4 grad_vec;
148- grad_vec = *(uint4 *)&partial_grad[idx];
149+ vector_t grad_vec;
150+ grad_vec = *(vector_t *)&partial_grad[idx];
149151#pragma unroll(ILP)
150152 for (int i = 0 ; i < ILP; i++) {
151153 auto grad = static_cast <accscalar_t >(*((scalar_t *)(&grad_vec) + i));
152154 grad *= normalizer;
153155 *((scalar_t *)(&grad_vec) + i) = static_cast <scalar_t >(grad);
154156 }
155- *(uint4 *)&partial_grad[idx] = grad_vec;
157+ *(vector_t *)&partial_grad[idx] = grad_vec;
156158}
157159
158160std::vector<at::Tensor> focal_loss_forward_cuda (
@@ -175,10 +177,10 @@ std::vector<at::Tensor> focal_loss_forward_cuda(
175177 " Mis-matched shape between class output and label." );
176178
177179 // Checks required for better performance
178- const int ILP = sizeof (uint4 ) / cls_output.element_size ();
179- ASSERT_UINT4_ALIGNED ( cls_output.data_ptr ());
180+ const int ILP = sizeof (vector_t ) / cls_output.element_size ();
181+ ASSERT_ALIGNED ( vector_t , cls_output.data_ptr ());
180182 TORCH_INTERNAL_ASSERT (cls_output.size (-1 ) % ILP == 0 ,
181- " Pad number of classes first to take advantage of 128 bit load." );
183+ " Pad number of classes first to take advantage of vectorized load." );
182184 TORCH_INTERNAL_ASSERT (num_real_classes >= ILP, " Too few classes." );
183185
184186 int64_t num_classes = cls_output.size (-1 );
@@ -190,12 +192,16 @@ std::vector<at::Tensor> focal_loss_forward_cuda(
190192 // helps with focal loss.
191193 at::Tensor partial_grad = at::empty_like (cls_output);
192194
193- // The grid contains 2 CTA per SM, each CTA loop on input with stride till the
194- // last item.
195+ // Set the number of CTAs per SM according to the compute capability.
196+ // Each CTA loops over the input with a stride until the last item.
195197 cudaDeviceProp props;
196198 cudaGetDeviceProperties (&props, at::cuda::current_device ());
199+ int cta_per_sm = 2 ;
200+ if (props.major >= 10 ) {
201+ cta_per_sm = 8 ;
202+ }
197203 dim3 block (512 );
198- dim3 grid (2 * props.multiProcessorCount );
204+ dim3 grid (cta_per_sm * props.multiProcessorCount );
199205
200206 // Specialize on label smoothing or not to reduce redundant operations
201207 cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
@@ -205,7 +211,7 @@ std::vector<at::Tensor> focal_loss_forward_cuda(
205211 using accscalar_t = at::acc_type<scalar_t , true >;
206212 using labelscalar_t = int64_t ;
207213 using outscalar_t = float ;
208- const int ILP = sizeof (uint4 ) / sizeof (scalar_t );
214+ const int ILP = sizeof (vector_t ) / sizeof (scalar_t );
209215 focal_loss_forward_cuda_kernel<false , ILP, scalar_t , labelscalar_t ,
210216 accscalar_t , outscalar_t >
211217 <<<grid, block, block.x * sizeof (accscalar_t ), stream>>> (
@@ -223,7 +229,7 @@ std::vector<at::Tensor> focal_loss_forward_cuda(
223229 using accscalar_t = at::acc_type<scalar_t , true >;
224230 using labelscalar_t = int64_t ;
225231 using outscalar_t = float ;
226- const int ILP = sizeof (uint4 ) / sizeof (scalar_t );
232+ const int ILP = sizeof (vector_t ) / sizeof (scalar_t );
227233 focal_loss_forward_cuda_kernel<true , ILP, scalar_t , labelscalar_t ,
228234 accscalar_t , outscalar_t >
229235 <<<grid, block, block.x * sizeof (accscalar_t ), stream>>> (
@@ -245,7 +251,7 @@ at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output,
245251 const at::Tensor &partial_grad,
246252 const at::Tensor &num_positives_sum) {
247253 // Each thread processes ILP elements
248- const int ILP = sizeof (uint4 ) / partial_grad.element_size ();
254+ const int ILP = sizeof (vector_t ) / partial_grad.element_size ();
249255 dim3 block (512 );
250256 dim3 grid ((partial_grad.numel () + block.x * ILP - 1 ) / (block.x * ILP));
251257
@@ -254,7 +260,7 @@ at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output,
254260 partial_grad.scalar_type (), " focal_loss_bprop" , [&] {
255261 using accscalar_t = at::acc_type<scalar_t , true >;
256262 using outscalar_t = float ;
257- const int ILP = sizeof (uint4 ) / sizeof (scalar_t );
263+ const int ILP = sizeof (vector_t ) / sizeof (scalar_t );
258264 focal_loss_backward_cuda_kernel<ILP, scalar_t , accscalar_t , outscalar_t >
259265 <<<grid, block, 0 , stream>>> (partial_grad.data_ptr <scalar_t >(),
260266 grad_output.data_ptr <outscalar_t >(),
0 commit comments