2525namespace {
2626
2727template <typename scalar_t >
28- __global__ void fused_rope_forward (int sq, int b, int np, int hn, int hn2,
28+ __global__ void fused_rope_forward (int h, int d, int d2, int stride_s,
29+ int stride_b, int stride_h, int stride_d,
30+ int o_stride_s, int o_stride_b,
31+ int o_stride_h, int o_stride_d,
2932 const scalar_t * src, const scalar_t * cos,
3033 const scalar_t * sin, scalar_t * dst) {
31- int sq_id = blockIdx.x , b_id = blockIdx.y ;
32- int offset_block = sq_id * b * np * hn + b_id * np * hn;
34+ int s_id = blockIdx.x , b_id = blockIdx.y ;
35+ int offset_block = s_id * stride_s + b_id * stride_b;
36+ int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
3337#pragma unroll
34- for (int hn_id = threadIdx.x ; hn_id < hn2; hn_id += blockDim.x ) {
35- scalar_t v_cos = cos[sq_id * hn2 + hn_id ];
36- scalar_t v_sin = sin[sq_id * hn2 + hn_id ];
38+ for (int d_id = threadIdx.x ; d_id < d2; d_id += blockDim.x ) {
39+ scalar_t v_cos = cos[s_id * d2 + d_id ];
40+ scalar_t v_sin = sin[s_id * d2 + d_id ];
3741#pragma unroll
38- for (int head_id = threadIdx.y ; head_id < np; head_id += blockDim.y ) {
39- int offset_src_dst = offset_block + head_id * hn + hn_id;
40- scalar_t v_src = src[offset_src_dst];
41- scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2)
42- ? -src[offset_src_dst + hn2 / 2 ]
43- : src[offset_src_dst + hn2 / 2 - hn2];
44- dst[offset_src_dst] = v_src * v_cos + v_src_rotate * v_sin;
42+ for (int h_id = threadIdx.y ; h_id < h; h_id += blockDim.y ) {
43+ int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
44+ int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
45+ scalar_t v_src = src[offset_src];
46+ scalar_t v_src_rotate = (d_id + d2 / 2 < d2)
47+ ? -src[offset_src + (d2 / 2 ) * stride_d]
48+ : src[offset_src + (d2 / 2 - d2) * stride_d];
49+ dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
4550 }
4651 }
4752
4853 // copy the rest
49- if (hn > hn2 ) {
54+ if (d > d2 ) {
5055#pragma unroll
51- for (int head_id = threadIdx.y ; head_id < np; head_id += blockDim.y ) {
52- int offset_head = offset_block + head_id * hn;
56+ for (int h_id = threadIdx.y ; h_id < h; h_id += blockDim.y ) {
57+ int offset_head = offset_block + h_id * stride_h;
58+ int offset_head_dst = offset_block_dst + h_id * o_stride_h;
5359#pragma unroll
54- for (int hn_id = hn2 + threadIdx.x ; hn_id < hn; hn_id += blockDim.x ) {
55- dst[offset_head + hn_id] = src[offset_head + hn_id];
60+ for (int d_id = d2 + threadIdx.x ; d_id < d; d_id += blockDim.x ) {
61+ dst[offset_head_dst + d_id * o_stride_d] =
62+ src[offset_head + d_id * stride_d];
5663 }
5764 }
5865 }
5966}
6067
6168template <typename scalar_t >
62- __global__ void fused_rope_backward (int sq, int b, int np, int hn, int hn2,
69+ __global__ void fused_rope_backward (int h, int d, int d2, int stride_s,
70+ int stride_b, int stride_h, int stride_d,
71+ int o_stride_s, int o_stride_b,
72+ int o_stride_h, int o_stride_d,
6373 const scalar_t * src, const scalar_t * cos,
6474 const scalar_t * sin, scalar_t * dst) {
65- int sq_id = blockIdx.x , b_id = blockIdx.y ;
66- int offset_block = sq_id * b * np * hn + b_id * np * hn;
75+ int s_id = blockIdx.x , b_id = blockIdx.y ;
76+ int offset_block = s_id * stride_s + b_id * stride_b;
77+ int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
6778#pragma unroll
68- for (int hn_id = threadIdx.x ; hn_id < hn2; hn_id += blockDim.x ) {
69- scalar_t v_cos = cos[sq_id * hn2 + hn_id ];
70- scalar_t v_sin = (hn_id + hn2 / 2 < hn2 )
71- ? sin[sq_id * hn2 + hn_id + hn2 / 2 ]
72- : -sin[sq_id * hn2 + hn_id + hn2 / 2 - hn2 ];
79+ for (int d_id = threadIdx.x ; d_id < d2; d_id += blockDim.x ) {
80+ scalar_t v_cos = cos[s_id * d2 + d_id ];
81+ scalar_t v_sin = (d_id + d2 / 2 < d2 )
82+ ? sin[s_id * d2 + d_id + d2 / 2 ]
83+ : -sin[s_id * d2 + d_id + d2 / 2 - d2 ];
7384#pragma unroll
74- for (int head_id = threadIdx.y ; head_id < np; head_id += blockDim.y ) {
75- int offset_src_dst = offset_block + head_id * hn + hn_id;
76- scalar_t v_src = src[offset_src_dst];
77- scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2)
78- ? src[offset_src_dst + hn2 / 2 ]
79- : src[offset_src_dst + hn2 / 2 - hn2];
80- dst[offset_src_dst] = v_src * v_cos + v_src_rotate * v_sin;
85+ for (int h_id = threadIdx.y ; h_id < h; h_id += blockDim.y ) {
86+ int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
87+ int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
88+ scalar_t v_src = src[offset_src];
89+ scalar_t v_src_rotate = (d_id + d2 / 2 < d2)
90+ ? src[offset_src + (d2 / 2 ) * stride_d]
91+ : src[offset_src + (d2 / 2 - d2) * stride_d];
92+ dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
8193 }
8294 }
8395
8496 // handle the tail
85- if (hn > hn2 ) {
97+ if (d > d2 ) {
8698#pragma unroll
87- for (int head_id = threadIdx.y ; head_id < np; head_id += blockDim.y ) {
88- int offset_head = offset_block + head_id * hn;
99+ for (int h_id = threadIdx.y ; h_id < h; h_id += blockDim.y ) {
100+ int offset_head = offset_block + h_id * stride_h;
101+ int offset_head_dst = offset_block_dst + h_id * o_stride_h;
89102#pragma unroll
90- for (int hn_id = hn2 + threadIdx.x ; hn_id < hn; hn_id += blockDim.x ) {
91- dst[offset_head + hn_id ] = src[offset_head + hn_id ];
103+ for (int d_id = d2 + threadIdx.x ; d_id < d; d_id += blockDim.x ) {
104+ dst[offset_head_dst + d_id * o_stride_d ] = src[offset_head + d_id * stride_d ];
92105 }
93106 }
94107 }
@@ -97,32 +110,40 @@ __global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2,
97110} // end of anonymous namespace
98111
99112template <typename scalar_t >
100- void dispatch_fused_rope_forward (int sq, int b, int np, int hn, int hn2,
113+ void dispatch_fused_rope_forward (int s, int b, int h, int d, int d2,
114+ int stride_s, int stride_b, int stride_h,
115+ int stride_d, int o_stride_s, int o_stride_b,
116+ int o_stride_h, int o_stride_d,
101117 const scalar_t * input, const scalar_t * cos,
102118 const scalar_t * sin, scalar_t * output) {
103119 auto stream = at::cuda::getCurrentCUDAStream ();
104120
105- int warps_per_block = np < 16 ? 4 : 8 ;
106- dim3 blocks (sq , b);
121+ int warps_per_block = h < 16 ? 4 : 8 ;
122+ dim3 blocks (s , b);
107123 dim3 threads (C10_WARP_SIZE, warps_per_block);
108124
109- fused_rope_forward<<<blocks, threads, 0 , stream>>>(sq, b, np, hn, hn2, input,
110- cos, sin, output);
125+ fused_rope_forward<<<blocks, threads, 0 , stream>>>(
126+ h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
127+ o_stride_h, o_stride_d, input, cos, sin, output);
111128 C10_CUDA_KERNEL_LAUNCH_CHECK ();
112129}
113130
114131template <typename scalar_t >
115- void dispatch_fused_rope_backward (int sq, int b, int np, int hn, int hn2,
132+ void dispatch_fused_rope_backward (int s, int b, int h, int d, int d2,
133+ int stride_s, int stride_b, int stride_h,
134+ int stride_d, int o_stride_s, int o_stride_b,
135+ int o_stride_h, int o_stride_d,
116136 const scalar_t * output_grads,
117137 const scalar_t * cos, const scalar_t * sin,
118138 scalar_t * input_grads) {
119139 auto stream = at::cuda::getCurrentCUDAStream ();
120140
121- int warps_per_block = np < 16 ? 4 : 8 ;
122- dim3 blocks (sq , b);
141+ int warps_per_block = h < 16 ? 4 : 8 ;
142+ dim3 blocks (s , b);
123143 dim3 threads (C10_WARP_SIZE, warps_per_block);
124144
125145 fused_rope_backward<<<blocks, threads, 0 , stream>>>(
126- sq, b, np, hn, hn2, output_grads, cos, sin, input_grads);
146+ h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
147+ o_stride_h, o_stride_d, output_grads, cos, sin, input_grads);
127148 C10_CUDA_KERNEL_LAUNCH_CHECK ();
128149}
0 commit comments