@@ -321,10 +321,9 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(int chunk_size, at::
321321 per_tensor ? output_per_tensor.data_ptr <float >() : nullptr , per_tensor,
322322 max_chunks_per_tensor);
323323 } else {
324- multi_tensor_apply<1 >(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
325- L2NormFunctor<scalar_t_0, int32_t >(), output.data_ptr <float >(),
326- per_tensor ? output_per_tensor.data_ptr <float >() : nullptr , per_tensor,
327- max_chunks_per_tensor);
324+ multi_tensor_apply<1 >(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0, int32_t >(),
325+ output.data_ptr <float >(), per_tensor ? output_per_tensor.data_ptr <float >() : nullptr ,
326+ per_tensor, max_chunks_per_tensor);
328327 })
329328
330329 AT_CUDA_CHECK (cudaGetLastError ());
@@ -428,16 +427,17 @@ void multi_tensor_norm_out_cuda(int chunk_size, at::Tensor noop_flag, std::vecto
428427 output_per_tensor = at::zeros ({ntensors * max_chunks_per_tensor}, float_options);
429428
430429 if (norm_type == 0 ) {
431- DISPATCH_FLOAT_AND_HALF (tensor_lists[0 ][0 ].scalar_type (), 0 , " multi_tensor_maxnorm_cuda" ,
432- if (requires_64bit_indexing) {
433- multi_tensor_apply<1 >((int64_t )BLOCK_SIZE, (int64_t )chunk_size, noop_flag, tensor_lists,
434- MaxNormFunctor<scalar_t_0, int64_t >(), output.data_ptr <float >(),
435- output_per_tensor.data_ptr <float >(), true , max_chunks_per_tensor);
436- } else {
437- multi_tensor_apply<1 >(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
438- MaxNormFunctor<scalar_t_0, int32_t >(), output.data_ptr <float >(),
439- output_per_tensor.data_ptr <float >(), true , max_chunks_per_tensor);
440- })
430+ DISPATCH_FLOAT_AND_HALF (
431+ tensor_lists[0 ][0 ].scalar_type (), 0 , " multi_tensor_maxnorm_cuda" ,
432+ if (requires_64bit_indexing) {
433+ multi_tensor_apply<1 >((int64_t )BLOCK_SIZE, (int64_t )chunk_size, noop_flag, tensor_lists,
434+ MaxNormFunctor<scalar_t_0, int64_t >(), output.data_ptr <float >(),
435+ output_per_tensor.data_ptr <float >(), true , max_chunks_per_tensor);
436+ } else {
437+ multi_tensor_apply<1 >(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, MaxNormFunctor<scalar_t_0, int32_t >(),
438+ output.data_ptr <float >(), output_per_tensor.data_ptr <float >(), true ,
439+ max_chunks_per_tensor);
440+ })
441441 } else {
442442 DISPATCH_FLOAT_HALF_AND_BFLOAT (
443443 tensor_lists[0 ][0 ].scalar_type (), 0 , " multi_tensor_l2norm_cuda" ,
@@ -446,9 +446,9 @@ void multi_tensor_norm_out_cuda(int chunk_size, at::Tensor noop_flag, std::vecto
446446 L2NormFunctor<scalar_t_0, int64_t >(), output.data_ptr <float >(),
447447 output_per_tensor.data_ptr <float >(), true , max_chunks_per_tensor);
448448 } else {
449- multi_tensor_apply<1 >(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
450- L2NormFunctor<scalar_t_0, int32_t >(), output .data_ptr <float >(),
451- output_per_tensor. data_ptr < float >(), true , max_chunks_per_tensor);
449+ multi_tensor_apply<1 >(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<scalar_t_0, int32_t >(),
450+ output. data_ptr < float >(), output_per_tensor .data_ptr <float >(), true ,
451+ max_chunks_per_tensor);
452452 })
453453 }
454454 AT_CUDA_CHECK (cudaGetLastError ());
0 commit comments