Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit dc5fa38

Browse files
authored
Avoid .contiguous() in fused RoPE (#1751)
* avoid input.contiguous() in fused_rope Signed-off-by: Xin Yao <[email protected]> * add transpose_output_memory Signed-off-by: Xin Yao <[email protected]> --------- Signed-off-by: Xin Yao <[email protected]>
1 parent a2f6683 commit dc5fa38

5 files changed

Lines changed: 213 additions & 93 deletions

File tree

apex/transformer/functional/fused_rope.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,23 @@
1717

1818

1919
class FusedRoPEFunc(torch.autograd.Function):
20+
"""Fused RoPE function"""
21+
2022
@staticmethod
2123
def forward(
22-
ctx, t: torch.Tensor, cos_: torch.Tensor, sin_: torch.Tensor
24+
ctx,
25+
t: torch.Tensor,
26+
cos_: torch.Tensor,
27+
sin_: torch.Tensor,
28+
transpose_output_memory: bool = False,
2329
) -> torch.Tensor:
2430
import fused_rotary_positional_embedding
2531

26-
output = fused_rotary_positional_embedding.forward(t, cos_, sin_)
32+
output = fused_rotary_positional_embedding.forward(
33+
t, cos_, sin_, transpose_output_memory
34+
)
2735
ctx.save_for_backward(cos_, sin_)
36+
ctx.transpose_output_memory = transpose_output_memory
2837

2938
return output
3039

@@ -35,39 +44,54 @@ def backward(
3544
import fused_rotary_positional_embedding
3645

3746
cos_, sin_ = ctx.saved_tensors
38-
grad_q = fused_rotary_positional_embedding.backward(grad_output, cos_, sin_)
47+
grad_input = fused_rotary_positional_embedding.backward(
48+
grad_output, cos_, sin_, ctx.transpose_output_memory
49+
)
3950

40-
return grad_q, None, None
51+
return grad_input, None, None, None
4152

4253

43-
def fused_apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
54+
def fused_apply_rotary_pos_emb(
55+
t: torch.Tensor,
56+
freqs: torch.Tensor,
57+
transpose_output_memory: bool = False,
58+
) -> torch.Tensor:
4459
"""Apply rotary positional embedding to input tensor T.
4560
4661
Args:
4762
t (Tensor): Input tensor T is of shape [seq_length, ... , dim]
4863
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim]
64+
transpose_output_memory (bool): Default to False. Whether to transpose the 's' and 'b'
65+
dimension of the output's underlying memory format. This is very helpful when you want to
66+
get a contiguous tensor after calling `output.transpose(0, 1)`.
4967
5068
Returns:
5169
Tensor: The input tensor after applying RoPE
5270
"""
5371
cos_ = torch.cos(freqs).to(t.dtype)
5472
sin_ = torch.sin(freqs).to(t.dtype)
55-
return FusedRoPEFunc.apply(t, cos_, sin_)
73+
return FusedRoPEFunc.apply(t, cos_, sin_, transpose_output_memory)
5674

5775

5876
def fused_apply_rotary_pos_emb_cached(
59-
t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
77+
t: torch.Tensor,
78+
cos: torch.Tensor,
79+
sin: torch.Tensor,
80+
transpose_output_memory: bool = False,
6081
) -> torch.Tensor:
6182
"""Apply rotary positional embedding to input tensor T.
6283
6384
Args:
6485
t (Tensor): Input tensor T is of shape [seq_length, ... , dim]
6586
cos (Tensor): Cached cosine of the rotary positional embedding tensor is of shape [seq_length, ..., dim]
6687
sin (Tensor): Cached sine of the rotary positional embedding tensor is of shape [seq_length, ..., dim]
88+
transpose_output_memory (bool): Default to False. Whether to transpose the 's' and 'b'
89+
dimension of the output's underlying memory format. This is very helpful when you want to
90+
get a contiguous tensor after calling `output.transpose(0, 1)`.
6791
6892
Returns:
6993
Tensor: The input tensor after applying RoPE
7094
"""
7195
cos_ = cos.to(t.dtype)
7296
sin_ = sin.to(t.dtype)
73-
return FusedRoPEFunc.apply(t, cos_, sin_)
97+
return FusedRoPEFunc.apply(t, cos_, sin_, transpose_output_memory)

csrc/megatron/fused_rotary_positional_embedding.cpp

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,14 @@
1919
namespace fused_rope {
2020

2121
torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &cos,
22-
const torch::Tensor &sin);
22+
const torch::Tensor &sin, const bool transpose_output);
2323

2424
torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
25-
const torch::Tensor &cos, const torch::Tensor &sin);
25+
const torch::Tensor &cos, const torch::Tensor &sin,
26+
const bool transpose_output);
2627

27-
torch::Tensor fwd(const at::Tensor &input_, const at::Tensor &cos_,
28-
const at::Tensor &sin_) {
29-
auto input = input_.contiguous();
30-
auto cos = cos_.contiguous();
31-
auto sin = sin_.contiguous();
28+
torch::Tensor fwd(const at::Tensor &input, const at::Tensor &cos,
29+
const at::Tensor &sin, const bool transpose_output) {
3230
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
3331
TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
3432
TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
@@ -47,14 +45,11 @@ torch::Tensor fwd(const at::Tensor &input_, const at::Tensor &cos_,
4745
"expected the last dim of the input tensor is greater than the "
4846
"sin tensor");
4947

50-
return fwd_cuda(input, cos, sin);
48+
return fwd_cuda(input, cos, sin, transpose_output);
5149
}
5250

53-
torch::Tensor bwd(const torch::Tensor &output_grads_, const at::Tensor &cos_,
54-
const at::Tensor &sin_) {
55-
auto output_grads = output_grads_.contiguous();
56-
auto cos = cos_.contiguous();
57-
auto sin = sin_.contiguous();
51+
torch::Tensor bwd(const torch::Tensor &output_grads, const at::Tensor &cos,
52+
const at::Tensor &sin, const bool transpose_output) {
5853
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
5954
TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
6055
TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
@@ -77,7 +72,7 @@ torch::Tensor bwd(const torch::Tensor &output_grads_, const at::Tensor &cos_,
7772
"expected the last dim of the output_grads tensor is greater than the "
7873
"sin tensor");
7974

80-
return bwd_cuda(output_grads, cos, sin);
75+
return bwd_cuda(output_grads, cos, sin, transpose_output);
8176
}
8277

8378
} // end namespace fused_rope

csrc/megatron/fused_rotary_positional_embedding.h

Lines changed: 68 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -25,70 +25,83 @@
2525
namespace {
2626

2727
template <typename scalar_t>
28-
__global__ void fused_rope_forward(int sq, int b, int np, int hn, int hn2,
28+
__global__ void fused_rope_forward(int h, int d, int d2, int stride_s,
29+
int stride_b, int stride_h, int stride_d,
30+
int o_stride_s, int o_stride_b,
31+
int o_stride_h, int o_stride_d,
2932
const scalar_t* src, const scalar_t* cos,
3033
const scalar_t* sin, scalar_t* dst) {
31-
int sq_id = blockIdx.x, b_id = blockIdx.y;
32-
int offset_block = sq_id * b * np * hn + b_id * np * hn;
34+
int s_id = blockIdx.x, b_id = blockIdx.y;
35+
int offset_block = s_id * stride_s + b_id * stride_b;
36+
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
3337
#pragma unroll
34-
for (int hn_id = threadIdx.x; hn_id < hn2; hn_id += blockDim.x) {
35-
scalar_t v_cos = cos[sq_id * hn2 + hn_id];
36-
scalar_t v_sin = sin[sq_id * hn2 + hn_id];
38+
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
39+
scalar_t v_cos = cos[s_id * d2 + d_id];
40+
scalar_t v_sin = sin[s_id * d2 + d_id];
3741
#pragma unroll
38-
for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
39-
int offset_src_dst = offset_block + head_id * hn + hn_id;
40-
scalar_t v_src = src[offset_src_dst];
41-
scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2)
42-
? -src[offset_src_dst + hn2 / 2]
43-
: src[offset_src_dst + hn2 / 2 - hn2];
44-
dst[offset_src_dst] = v_src * v_cos + v_src_rotate * v_sin;
42+
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
43+
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
44+
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
45+
scalar_t v_src = src[offset_src];
46+
scalar_t v_src_rotate = (d_id + d2 / 2 < d2)
47+
? -src[offset_src + (d2 / 2) * stride_d]
48+
: src[offset_src + (d2 / 2 - d2) * stride_d];
49+
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
4550
}
4651
}
4752

4853
// copy the rest
49-
if (hn > hn2) {
54+
if (d > d2) {
5055
#pragma unroll
51-
for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
52-
int offset_head = offset_block + head_id * hn;
56+
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
57+
int offset_head = offset_block + h_id * stride_h;
58+
int offset_head_dst = offset_block_dst + h_id * o_stride_h;
5359
#pragma unroll
54-
for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
55-
dst[offset_head + hn_id] = src[offset_head + hn_id];
60+
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
61+
dst[offset_head_dst + d_id * o_stride_d] =
62+
src[offset_head + d_id * stride_d];
5663
}
5764
}
5865
}
5966
}
6067

6168
template <typename scalar_t>
62-
__global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2,
69+
__global__ void fused_rope_backward(int h, int d, int d2, int stride_s,
70+
int stride_b, int stride_h, int stride_d,
71+
int o_stride_s, int o_stride_b,
72+
int o_stride_h, int o_stride_d,
6373
const scalar_t* src, const scalar_t* cos,
6474
const scalar_t* sin, scalar_t* dst) {
65-
int sq_id = blockIdx.x, b_id = blockIdx.y;
66-
int offset_block = sq_id * b * np * hn + b_id * np * hn;
75+
int s_id = blockIdx.x, b_id = blockIdx.y;
76+
int offset_block = s_id * stride_s + b_id * stride_b;
77+
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
6778
#pragma unroll
68-
for (int hn_id = threadIdx.x; hn_id < hn2; hn_id += blockDim.x) {
69-
scalar_t v_cos = cos[sq_id * hn2 + hn_id];
70-
scalar_t v_sin = (hn_id + hn2 / 2 < hn2)
71-
? sin[sq_id * hn2 + hn_id + hn2 / 2]
72-
: -sin[sq_id * hn2 + hn_id + hn2 / 2 - hn2];
79+
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
80+
scalar_t v_cos = cos[s_id * d2 + d_id];
81+
scalar_t v_sin = (d_id + d2 / 2 < d2)
82+
? sin[s_id * d2 + d_id + d2 / 2]
83+
: -sin[s_id * d2 + d_id + d2 / 2 - d2];
7384
#pragma unroll
74-
for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
75-
int offset_src_dst = offset_block + head_id * hn + hn_id;
76-
scalar_t v_src = src[offset_src_dst];
77-
scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2)
78-
? src[offset_src_dst + hn2 / 2]
79-
: src[offset_src_dst + hn2 / 2 - hn2];
80-
dst[offset_src_dst] = v_src * v_cos + v_src_rotate * v_sin;
85+
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
86+
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
87+
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
88+
scalar_t v_src = src[offset_src];
89+
scalar_t v_src_rotate = (d_id + d2 / 2 < d2)
90+
? src[offset_src + (d2 / 2) * stride_d]
91+
: src[offset_src + (d2 / 2 - d2) * stride_d];
92+
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
8193
}
8294
}
8395

8496
// handle the tail
85-
if (hn > hn2) {
97+
if (d > d2) {
8698
#pragma unroll
87-
for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
88-
int offset_head = offset_block + head_id * hn;
99+
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
100+
int offset_head = offset_block + h_id * stride_h;
101+
int offset_head_dst = offset_block_dst + h_id * o_stride_h;
89102
#pragma unroll
90-
for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
91-
dst[offset_head + hn_id] = src[offset_head + hn_id];
103+
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
104+
dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];
92105
}
93106
}
94107
}
@@ -97,32 +110,40 @@ __global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2,
97110
} // end of anonymous namespace
98111

99112
template <typename scalar_t>
100-
void dispatch_fused_rope_forward(int sq, int b, int np, int hn, int hn2,
113+
void dispatch_fused_rope_forward(int s, int b, int h, int d, int d2,
114+
int stride_s, int stride_b, int stride_h,
115+
int stride_d, int o_stride_s, int o_stride_b,
116+
int o_stride_h, int o_stride_d,
101117
const scalar_t* input, const scalar_t* cos,
102118
const scalar_t* sin, scalar_t* output) {
103119
auto stream = at::cuda::getCurrentCUDAStream();
104120

105-
int warps_per_block = np < 16 ? 4 : 8;
106-
dim3 blocks(sq, b);
121+
int warps_per_block = h < 16 ? 4 : 8;
122+
dim3 blocks(s, b);
107123
dim3 threads(C10_WARP_SIZE, warps_per_block);
108124

109-
fused_rope_forward<<<blocks, threads, 0, stream>>>(sq, b, np, hn, hn2, input,
110-
cos, sin, output);
125+
fused_rope_forward<<<blocks, threads, 0, stream>>>(
126+
h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
127+
o_stride_h, o_stride_d, input, cos, sin, output);
111128
C10_CUDA_KERNEL_LAUNCH_CHECK();
112129
}
113130

114131
template <typename scalar_t>
115-
void dispatch_fused_rope_backward(int sq, int b, int np, int hn, int hn2,
132+
void dispatch_fused_rope_backward(int s, int b, int h, int d, int d2,
133+
int stride_s, int stride_b, int stride_h,
134+
int stride_d, int o_stride_s, int o_stride_b,
135+
int o_stride_h, int o_stride_d,
116136
const scalar_t* output_grads,
117137
const scalar_t* cos, const scalar_t* sin,
118138
scalar_t* input_grads) {
119139
auto stream = at::cuda::getCurrentCUDAStream();
120140

121-
int warps_per_block = np < 16 ? 4 : 8;
122-
dim3 blocks(sq, b);
141+
int warps_per_block = h < 16 ? 4 : 8;
142+
dim3 blocks(s, b);
123143
dim3 threads(C10_WARP_SIZE, warps_per_block);
124144

125145
fused_rope_backward<<<blocks, threads, 0, stream>>>(
126-
sq, b, np, hn, hn2, output_grads, cos, sin, input_grads);
146+
h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
147+
o_stride_h, o_stride_d, output_grads, cos, sin, input_grads);
127148
C10_CUDA_KERNEL_LAUNCH_CHECK();
128149
}

0 commit comments

Comments
 (0)