Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 37d83fc

Browse files
authored
[FusedRoPE] Fuse type conversion and cos/sin (#1752)
* minor fix * fuse type conversion Signed-off-by: Xin Yao <[email protected]> * fuse cos/sin Signed-off-by: Xin Yao <[email protected]> * update comments Signed-off-by: Xin Yao <[email protected]> * fix typo Signed-off-by: Xin Yao <[email protected]> * lint Signed-off-by: Xin Yao <[email protected]> * use TORCH_CHECK instead of AT_ERROR Signed-off-by: Xin Yao <[email protected]> --------- Signed-off-by: Xin Yao <[email protected]>
1 parent 7548f68 commit 37d83fc

5 files changed

Lines changed: 481 additions & 95 deletions

File tree

apex/transformer/functional/fused_rope.py

Lines changed: 66 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,27 @@
1717

1818

1919
class FusedRoPEFunc(torch.autograd.Function):
20-
"""Fused RoPE function"""
20+
"""
21+
Fused RoPE function
22+
23+
This implementation assumes the input tensor to be in `sbhd` format and the RoPE tensor to be
24+
of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid the expensive
25+
`.contiguous()` calls, thus it may not achieve the best memory access pattern.
26+
"""
2127

2228
@staticmethod
2329
def forward(
2430
ctx,
2531
t: torch.Tensor,
26-
cos_: torch.Tensor,
27-
sin_: torch.Tensor,
32+
freqs: torch.Tensor,
2833
transpose_output_memory: bool = False,
2934
) -> torch.Tensor:
3035
import fused_rotary_positional_embedding
3136

3237
output = fused_rotary_positional_embedding.forward(
33-
t, cos_, sin_, transpose_output_memory
38+
t, freqs, transpose_output_memory
3439
)
35-
ctx.save_for_backward(cos_, sin_)
40+
ctx.save_for_backward(freqs)
3641
ctx.transpose_output_memory = transpose_output_memory
3742

3843
return output
@@ -43,12 +48,12 @@ def backward(
4348
) -> Tuple[Union[torch.Tensor, None], ...]:
4449
import fused_rotary_positional_embedding
4550

46-
cos_, sin_ = ctx.saved_tensors
51+
(freqs,) = ctx.saved_tensors
4752
grad_input = fused_rotary_positional_embedding.backward(
48-
grad_output, cos_, sin_, ctx.transpose_output_memory
53+
grad_output, freqs, ctx.transpose_output_memory
4954
)
5055

51-
return grad_input, None, None, None
56+
return grad_input, None, None
5257

5358

5459
def fused_apply_rotary_pos_emb(
@@ -59,39 +64,79 @@ def fused_apply_rotary_pos_emb(
5964
"""Apply rotary positional embedding to input tensor T.
6065
6166
Args:
62-
t (Tensor): Input tensor T is of shape [seq_length, ... , dim]
63-
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim]
67+
t (Tensor): Input tensor T is of shape [s, b, h, d]
68+
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [s, 1, 1, d] and
69+
`float` dtype
6470
transpose_output_memory (bool): Default to False. Whether to transpose the 's' and 'b'
6571
dimension of the output's underlying memory format. This is very helpful when you want to
6672
get a contiguous tensor after calling `output.transpose(0, 1)`.
6773
6874
Returns:
6975
Tensor: The input tensor after applying RoPE
7076
"""
71-
cos_ = torch.cos(freqs).to(t.dtype)
72-
sin_ = torch.sin(freqs).to(t.dtype)
73-
return FusedRoPEFunc.apply(t, cos_, sin_, transpose_output_memory)
77+
return FusedRoPEFunc.apply(t, freqs, transpose_output_memory)
78+
79+
80+
class FusedRoPECachedFunc(torch.autograd.Function):
81+
"""
82+
Fused RoPE function
83+
84+
This implementation assumes the input tensor to be in `sbhd` format and the RoPE tensor to be
85+
of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid the expensive
86+
`.contiguous()` calls, thus it may not achieve the best memory access pattern.
87+
"""
88+
89+
@staticmethod
90+
def forward(
91+
ctx,
92+
t: torch.Tensor,
93+
cos_: torch.Tensor,
94+
sin_: torch.Tensor,
95+
transpose_output_memory: bool = False,
96+
) -> torch.Tensor:
97+
import fused_rotary_positional_embedding
98+
99+
output = fused_rotary_positional_embedding.forward_cached(
100+
t, cos_, sin_, transpose_output_memory
101+
)
102+
ctx.save_for_backward(cos_, sin_)
103+
ctx.transpose_output_memory = transpose_output_memory
104+
105+
return output
106+
107+
@staticmethod
108+
def backward(
109+
ctx, grad_output: torch.Tensor
110+
) -> Tuple[Union[torch.Tensor, None], ...]:
111+
import fused_rotary_positional_embedding
112+
113+
cos_, sin_ = ctx.saved_tensors
114+
grad_input = fused_rotary_positional_embedding.backward_cached(
115+
grad_output, cos_, sin_, ctx.transpose_output_memory
116+
)
117+
118+
return grad_input, None, None, None
74119

75120

76121
def fused_apply_rotary_pos_emb_cached(
77122
t: torch.Tensor,
78-
cos: torch.Tensor,
79-
sin: torch.Tensor,
123+
cos_: torch.Tensor,
124+
sin_: torch.Tensor,
80125
transpose_output_memory: bool = False,
81126
) -> torch.Tensor:
82127
"""Apply rotary positional embedding to input tensor T.
83128
84129
Args:
85-
t (Tensor): Input tensor T is of shape [seq_length, ... , dim]
86-
cos (Tensor): Cached cosine of the rotary positional embedding tensor is of shape [seq_length, ..., dim]
87-
sin (Tensor): Cached sine of the rotary positional embedding tensor is of shape [seq_length, ..., dim]
130+
t (Tensor): Input tensor T is of shape [s, b, h, d]
131+
cos_ (Tensor): Cached cosine of the rotary positional embedding tensor is of
132+
shape [s, 1, 1, d] and dtype either `float` or the same as `t`.
133+
sin_ (Tensor): Cached sine of the rotary positional embedding tensor is of
134+
shape [s, 1, 1, d] and dtype either `float` or the same as `t`.
88135
transpose_output_memory (bool): Default to False. Whether to transpose the 's' and 'b'
89136
dimension of the output's underlying memory format. This is very helpful when you want to
90137
get a contiguous tensor after calling `output.transpose(0, 1)`.
91138
92139
Returns:
93140
Tensor: The input tensor after applying RoPE
94141
"""
95-
cos_ = cos.to(t.dtype)
96-
sin_ = sin.to(t.dtype)
97-
return FusedRoPEFunc.apply(t, cos_, sin_, transpose_output_memory)
142+
return FusedRoPECachedFunc.apply(t, cos_, sin_, transpose_output_memory)

csrc/megatron/fused_rotary_positional_embedding.cpp

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,59 @@
1818

1919
namespace fused_rope {
2020

21-
torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &cos,
22-
const torch::Tensor &sin, const bool transpose_output);
21+
torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &freqs,
22+
const bool transpose_output);
2323

2424
torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
25-
const torch::Tensor &cos, const torch::Tensor &sin,
26-
const bool transpose_output);
25+
const torch::Tensor &freqs, const bool transpose_output);
26+
27+
torch::Tensor fwd_cached_cuda(const torch::Tensor &input,
28+
const torch::Tensor &cos,
29+
const torch::Tensor &sin,
30+
const bool transpose_output);
31+
32+
torch::Tensor bwd_cached_cuda(const torch::Tensor &output_grads,
33+
const torch::Tensor &cos,
34+
const torch::Tensor &sin,
35+
const bool transpose_output);
2736

28-
torch::Tensor fwd(const at::Tensor &input, const at::Tensor &cos,
29-
const at::Tensor &sin, const bool transpose_output) {
37+
torch::Tensor fwd(const at::Tensor &input, const at::Tensor &freqs,
38+
const bool transpose_output) {
39+
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
40+
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
41+
TORCH_CHECK(input.size(0) == freqs.size(0),
42+
"expected input and freqs tensor have the same sequence length");
43+
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
44+
"expected the second and third dims of the freqs tensor equal 1");
45+
TORCH_CHECK(input.size(3) >= freqs.size(3),
46+
"expected the last dim of the input tensor equals or is "
47+
"greater than the freqs tensor");
48+
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
49+
"Dtype of the freqs tensor must be float");
50+
51+
return fwd_cuda(input, freqs, transpose_output);
52+
}
53+
54+
torch::Tensor bwd(const torch::Tensor &output_grads, const at::Tensor &freqs,
55+
const bool transpose_output) {
56+
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
57+
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
58+
TORCH_CHECK(
59+
output_grads.size(0) == freqs.size(0),
60+
"expected output_grads and freqs tensor have the same sequence length");
61+
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
62+
"expected the second and third dims of the freqs tensor equal 1");
63+
TORCH_CHECK(output_grads.size(3) >= freqs.size(3),
64+
"expected the last dim of the output_grads tensor equals or is "
65+
"greater than the freqs tensor");
66+
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
67+
"Dtype of the freqs tensor must be float");
68+
69+
return bwd_cuda(output_grads, freqs, transpose_output);
70+
}
71+
72+
torch::Tensor fwd_cached(const at::Tensor &input, const at::Tensor &cos,
73+
const at::Tensor &sin, const bool transpose_output) {
3074
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
3175
TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
3276
TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
@@ -38,18 +82,20 @@ torch::Tensor fwd(const at::Tensor &input, const at::Tensor &cos,
3882
"expected the second and third dims of the cos tensor equal 1");
3983
TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
4084
"expected the second and third dims of the sin tensor equal 1");
85+
TORCH_CHECK(cos.size(3) == sin.size(3),
86+
"expected cos and sin tensor have the same last dim");
4187
TORCH_CHECK(input.size(3) >= cos.size(3),
42-
"expected the last dim of the input tensor is greater than the "
43-
"cos tensor");
44-
TORCH_CHECK(input.size(3) >= sin.size(3),
45-
"expected the last dim of the input tensor is greater than the "
46-
"sin tensor");
88+
"expected the last dim of the input tensor equals or is "
89+
"greater than the cos tensor");
90+
TORCH_CHECK(cos.scalar_type() == sin.scalar_type(),
91+
"expected cos and sin tensor have the same dtype");
4792

48-
return fwd_cuda(input, cos, sin, transpose_output);
93+
return fwd_cached_cuda(input, cos, sin, transpose_output);
4994
}
5095

51-
torch::Tensor bwd(const torch::Tensor &output_grads, const at::Tensor &cos,
52-
const at::Tensor &sin, const bool transpose_output) {
96+
torch::Tensor bwd_cached(const torch::Tensor &output_grads,
97+
const at::Tensor &cos, const at::Tensor &sin,
98+
const bool transpose_output) {
5399
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
54100
TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
55101
TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
@@ -63,16 +109,15 @@ torch::Tensor bwd(const torch::Tensor &output_grads, const at::Tensor &cos,
63109
"expected the second and third dims of the cos tensor equal 1");
64110
TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
65111
"expected the second and third dims of the sin tensor equal 1");
66-
TORCH_CHECK(
67-
output_grads.size(3) >= cos.size(3),
68-
"expected the last dim of the output_grads tensor is greater than the "
69-
"cos tensor");
70-
TORCH_CHECK(
71-
output_grads.size(3) >= sin.size(3),
72-
"expected the last dim of the output_grads tensor is greater than the "
73-
"sin tensor");
112+
TORCH_CHECK(cos.size(3) == sin.size(3),
113+
"expected cos and sin tensor have the same last dim");
114+
TORCH_CHECK(output_grads.size(3) >= cos.size(3),
115+
"expected the last dim of the output_grads tensor equals or is "
116+
"greater than the cos tensor");
117+
TORCH_CHECK(cos.scalar_type() == sin.scalar_type(),
118+
"expected cos and sin tensor have the same dtype");
74119

75-
return bwd_cuda(output_grads, cos, sin, transpose_output);
120+
return bwd_cached_cuda(output_grads, cos, sin, transpose_output);
76121
}
77122

78123
} // end namespace fused_rope
@@ -82,4 +127,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
82127
"Fused Rotary Positional Embedding -- Forward.");
83128
m.def("backward", &fused_rope::bwd,
84129
"Fused Rotary Positional Embedding -- Backward.");
130+
m.def("forward_cached", &fused_rope::fwd_cached,
131+
"Fused Rotary Positional Embedding Cached -- Forward.");
132+
m.def("backward_cached", &fused_rope::bwd_cached,
133+
"Fused Rotary Positional Embedding Cached -- Backward.");
85134
}

0 commit comments

Comments
 (0)