diff --git a/csrc/fused/fused.cu b/csrc/fused/fused.cu index fb8b9f15..9aeae5ed 100644 --- a/csrc/fused/fused.cu +++ b/csrc/fused/fused.cu @@ -478,6 +478,7 @@ void quant_per_block_int8_cuda( } auto input_dtype = input.scalar_type(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { @@ -492,7 +493,7 @@ void quant_per_block_int8_cuda( dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); - QuantInt8Kernel<<>>( + QuantInt8Kernel<<>>( reinterpret_cast(input.data_ptr()), nullptr, output.data_ptr(), @@ -560,6 +561,7 @@ void quant_per_block_int8_cuda( } auto input_dtype = input.scalar_type(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { @@ -574,7 +576,7 @@ void quant_per_block_int8_cuda( dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); - QuantInt8Kernel<<>>( + QuantInt8Kernel<<>>( reinterpret_cast(input.data_ptr()), nullptr, output.data_ptr(), @@ -647,6 +649,7 @@ void quant_per_block_int8_fuse_sub_mean_cuda( auto input_dtype = input.scalar_type(); auto mean_dtype = mean.scalar_type(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); TORCH_CHECK(input_dtype == mean_dtype, "Input and mean must have the same data type"); @@ -664,7 +667,7 @@ void quant_per_block_int8_fuse_sub_mean_cuda( dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); - QuantInt8Kernel<<>>( + QuantInt8Kernel<<>>( reinterpret_cast(input.data_ptr()), reinterpret_cast(mean.data_ptr()), output.data_ptr(), @@ -734,6 +737,7 @@ void quant_per_warp_int8_cuda( } auto input_dtype = input.scalar_type(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { @@ -749,7 +753,7 @@ void quant_per_warp_int8_cuda( dim3 block(WARP_BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); - QuantInt8Kernel<<>>( + QuantInt8Kernel<<>>( reinterpret_cast(input.data_ptr()), nullptr, output.data_ptr(), @@ -817,6 +821,7 @@ void sub_mean_cuda( auto input_dtype = input.scalar_type(); auto mean_dtype = mean.scalar_type(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); TORCH_CHECK(input_dtype == mean_dtype, "Input and mean must have the same data type"); @@ -834,7 +839,7 @@ void sub_mean_cuda( dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); - SubMeanKernel<<>>( + SubMeanKernel<<>>( reinterpret_cast(input.data_ptr()), reinterpret_cast(mean.data_ptr()), reinterpret_cast(output.data_ptr()), @@ -900,6 +905,7 @@ void transpose_pad_permute_cuda( auto input_dtype = input.scalar_type(); auto output_dtype = output.scalar_type(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); TORCH_CHECK(input_dtype == output_dtype, "Input and output must have the same data type"); @@ -911,7 +917,7 @@ void transpose_pad_permute_cuda( dim3 block(CTA_SIZE * (HEAD_DIM / 8)); - TransposePadPermuteKernel<<>>( + TransposePadPermuteKernel<<>>( reinterpret_cast(input.data_ptr()), reinterpret_cast(output.data_ptr()), num_tokens, @@ -982,9 +988,10 @@ void scale_fuse_quant_cuda( dim3 block(CTA_SIZE); auto input_dtype = input.scalar_type(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { - MeanScaleKernel<64, false, c_type><<>>( + MeanScaleKernel<64, false, c_type><<>>( reinterpret_cast(input.data_ptr()), reinterpret_cast(output.data_ptr()), nullptr, @@ -1065,9 +1072,10 @@ void mean_scale_fuse_quant_cuda( dim3 block(CTA_SIZE); auto input_dtype = input.scalar_type(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { - MeanScaleKernel<64, true, c_type><<>>( + MeanScaleKernel<64, true, c_type><<>>( reinterpret_cast(input.data_ptr()), reinterpret_cast(output.data_ptr()), reinterpret_cast(mean.data_ptr()), diff --git a/csrc/qattn/qk_int_sv_f16_cuda_sm80.cu b/csrc/qattn/qk_int_sv_f16_cuda_sm80.cu index bdedaf66..5fc8c4c4 100644 --- a/csrc/qattn/qk_int_sv_f16_cuda_sm80.cu +++ b/csrc/qattn/qk_int_sv_f16_cuda_sm80.cu @@ -17,6 +17,7 @@ #include "../utils.cuh" #include #include +#include #include #include "../cp_async.cuh" @@ -718,6 +719,7 @@ torch::Tensor qk_int8_sv_f16_accum_f32_attn(torch::Tensor query, int stride_bz_v = value.stride(0); int stride_bz_o = output.stride(0); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int qo_len, kv_len, num_qo_heads, num_kv_heads; int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; int stride_h_q, stride_h_k, stride_h_v, stride_h_o; @@ -819,7 +821,7 @@ torch::Tensor qk_int8_sv_f16_accum_f32_attn(torch::Tensor query, dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); - kernel_func<<>>( + kernel_func<<>>( query.data_ptr(), key.data_ptr(), reinterpret_cast(value.data_ptr()), @@ -892,6 +894,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_attn(torch::Tensor query, int stride_bz_v = value.stride(0); int stride_bz_o = output.stride(0); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int qo_len, kv_len, num_qo_heads, num_kv_heads; int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; int stride_h_q, stride_h_k, stride_h_v, stride_h_o; @@ -994,7 +997,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_attn(torch::Tensor query, dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); - kernel_func<<>>( + kernel_func<<>>( query.data_ptr(), key.data_ptr(), reinterpret_cast(value.data_ptr()), @@ -1067,6 +1070,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_attn_inst_buf(torch::Tensor query, int stride_bz_v = value.stride(0); int stride_bz_o = output.stride(0); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int qo_len, kv_len, num_qo_heads, num_kv_heads; int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; int stride_h_q, stride_h_k, stride_h_v, stride_h_o; @@ -1169,7 +1173,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_attn_inst_buf(torch::Tensor query, dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); - kernel_func<<>>( + kernel_func<<>>( query.data_ptr(), key.data_ptr(), reinterpret_cast(value.data_ptr()), @@ -1246,6 +1250,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(torch::Tensor query, int stride_bz_v = value.stride(0); int stride_bz_o = output.stride(0); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int qo_len, kv_len, num_qo_heads, num_kv_heads; int stride_seq_q, stride_seq_k, stride_seq_v, stride_seq_o; int stride_h_q, stride_h_k, stride_h_v, stride_h_o; @@ -1353,7 +1358,7 @@ torch::Tensor qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(torch::Tensor query, dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); - kernel_func<<>>( + kernel_func<<>>( query.data_ptr(), key.data_ptr(), reinterpret_cast(value.data_ptr()), diff --git a/csrc/qattn/qk_int_sv_f8_cuda_sm89.cu b/csrc/qattn/qk_int_sv_f8_cuda_sm89.cu index 3b1b4305..0552ecdc 100644 --- a/csrc/qattn/qk_int_sv_f8_cuda_sm89.cu +++ b/csrc/qattn/qk_int_sv_f8_cuda_sm89.cu @@ -17,6 +17,7 @@ #include "../utils.cuh" #include #include +#include #include #include "../cp_async.cuh" @@ -733,6 +734,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn(torch::Tensor query, int stride_bz_v = value.stride(0); int stride_bz_o = output.stride(0); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int qo_len, kv_len, num_qo_heads, num_kv_heads; int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; @@ -836,7 +838,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn(torch::Tensor query, dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); - kernel_func<<>>( + kernel_func<<>>( query.data_ptr(), key.data_ptr(), reinterpret_cast(value.data_ptr()), @@ -911,6 +913,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(torch::Tensor query, int stride_bz_v = value.stride(0); int stride_bz_o = output.stride(0); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int qo_len, kv_len, num_qo_heads, num_kv_heads; int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; @@ -1014,7 +1017,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(torch::Tensor query, dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); - kernel_func<<>>( + kernel_func<<>>( query.data_ptr(), key.data_ptr(), reinterpret_cast(value.data_ptr()), @@ -1099,6 +1102,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(torch::Tenso int stride_bz_v = value.stride(0); int stride_bz_o = output.stride(0); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int qo_len, kv_len, num_qo_heads, num_kv_heads; int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; @@ -1205,7 +1209,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(torch::Tenso dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); - kernel_func<<>>( + kernel_func<<>>( query.data_ptr(), key.data_ptr(), reinterpret_cast(value.data_ptr()), @@ -1285,6 +1289,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(torch::Tensor query, int stride_bz_v = value.stride(0); int stride_bz_o = output.stride(0); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int qo_len, kv_len, num_qo_heads, num_kv_heads; int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; @@ -1391,7 +1396,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(torch::Tensor query, dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); - kernel_func<<>>( + kernel_func<<>>( query.data_ptr(), key.data_ptr(), reinterpret_cast(value.data_ptr()), @@ -1471,6 +1476,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(torch::Tensor q int stride_bz_v = value.stride(0); int stride_bz_o = output.stride(0); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int qo_len, kv_len, num_qo_heads, num_kv_heads; int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; @@ -1577,7 +1583,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(torch::Tensor q dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); dim3 block(32, (CTA_Q / WARP_Q) * (CTA_K / WARP_K)); - kernel_func<<>>( + kernel_func<<>>( query.data_ptr(), key.data_ptr(), reinterpret_cast(value.data_ptr()), diff --git a/csrc/qattn/qk_int_sv_f8_cuda_sm90.cu b/csrc/qattn/qk_int_sv_f8_cuda_sm90.cu index e9e5ccf5..1775bff7 100644 --- a/csrc/qattn/qk_int_sv_f8_cuda_sm90.cu +++ b/csrc/qattn/qk_int_sv_f8_cuda_sm90.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include #include "../wgmma.cuh" @@ -614,6 +615,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf( int stride_bz_v = value.stride(0); int stride_bz_o = output.stride(0); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int qo_len, kv_len, padded_kv_len, num_qo_heads, num_kv_heads; int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; @@ -717,7 +719,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf( cudaFuncAttributeMaxDynamicSharedMemorySize, sMemSize); dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); - kernel<<>>( + kernel<<>>( tma_map_Q, tma_map_K, tma_map_V, @@ -790,6 +792,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( int stride_bz_v = value.stride(0); int stride_bz_o = output.stride(0); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int qo_len, kv_len, padded_kv_len, num_qo_heads, num_kv_heads; int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; @@ -895,7 +898,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( cudaFuncAttributeMaxDynamicSharedMemorySize, sMemSize); dim3 grid(div_ceil(qo_len, CTA_Q), num_qo_heads, batch_size); - kernel<<>>( + kernel<<>>( tma_map_Q, tma_map_K, tma_map_V, @@ -913,4 +916,4 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( }); return lse; -} \ No newline at end of file +}