1818
1919namespace fused_rope {
2020
// Declaration of fwd_cuda; no body is visible in this view (it is only
// declared here and called from fwd below).
// The '-' lines show the former parameter list (input, cos, sin,
// transpose_output); the '+' lines show its replacement (input, freqs,
// transpose_output), i.e. the separate cos/sin tensors give way to a single
// freqs tensor.
21- torch::Tensor fwd_cuda (const torch::Tensor &input, const torch::Tensor &cos ,
22- const torch::Tensor &sin, const bool transpose_output);
21+ torch::Tensor fwd_cuda (const torch::Tensor &input, const torch::Tensor &freqs ,
22+ const bool transpose_output);
2323
// Declaration of bwd_cuda; no body is visible in this view (it is only
// declared here and called from bwd below).
// The '-' lines show the former trailing parameters (cos, sin,
// transpose_output); the '+' line replaces them with (freqs,
// transpose_output). The first parameter, output_grads, is unchanged.
2424torch::Tensor bwd_cuda (const torch::Tensor &output_grads,
25- const torch::Tensor &cos, const torch::Tensor &sin,
26- const bool transpose_output);
25+ const torch::Tensor &freqs, const bool transpose_output);
26+
// Declaration of fwd_cached_cuda, newly added by this change; no body is
// visible in this view. It takes the (input, cos, sin, transpose_output)
// parameter list that fwd_cuda had before the change and is the function
// fwd_cached delegates to.
27+ torch::Tensor fwd_cached_cuda (const torch::Tensor &input,
28+ const torch::Tensor &cos,
29+ const torch::Tensor &sin,
30+ const bool transpose_output);
31+
// Declaration of bwd_cached_cuda, newly added by this change; no body is
// visible in this view. It takes the (output_grads, cos, sin,
// transpose_output) parameter list that bwd_cuda had before the change and is
// the function bwd_cached delegates to.
32+ torch::Tensor bwd_cached_cuda (const torch::Tensor &output_grads,
33+ const torch::Tensor &cos,
34+ const torch::Tensor &sin,
35+ const bool transpose_output);
2736
28- torch::Tensor fwd (const at::Tensor &input, const at::Tensor &cos,
29- const at::Tensor &sin, const bool transpose_output) {
// Forward entry point: validates `input` and `freqs`, then delegates to
// fwd_cuda and returns its result unchanged.
//
// Parameters:
//   input            - must be a 4D tensor.
//   freqs            - must be a 4D float tensor whose dims 1 and 2 are both 1.
//   transpose_output - forwarded untouched to fwd_cuda; its effect is not
//                      visible in this block.
//
// Every precondition below is enforced with TORCH_CHECK, which reports the
// given message when the condition is false.
//
// NOTE(review): the message literals in this view start with a space, and the
// two adjacent literals in the last-dim check concatenate to
// "... is  greater ..." (two spaces). This may be an artifact of how the text
// was captured -- confirm against the real file.
37+ torch::Tensor fwd (const at::Tensor &input, const at::Tensor &freqs,
38+ const bool transpose_output) {
// Rank checks: both tensors must have exactly four dimensions.
39+ TORCH_CHECK (input.dim () == 4 , " expected 4D tensor" );
40+ TORCH_CHECK (freqs.dim () == 4 , " expected 4D tensor" );
// Dim 0 must match; the message refers to it as the sequence length.
41+ TORCH_CHECK (input.size (0 ) == freqs.size (0 ),
42+ " expected input and freqs tensor have the same sequence length" );
// freqs must be singleton in dims 1 and 2.
43+ TORCH_CHECK (freqs.size (1 ) == 1 && freqs.size (2 ) == 1 ,
44+ " expected the second and third dims of the freqs tensor equal 1" );
// freqs may be narrower than input along the last dim, but never wider.
45+ TORCH_CHECK (input.size (3 ) >= freqs.size (3 ),
46+ " expected the last dim of the input tensor equals or is "
47+ " greater than the freqs tensor" );
// Only float freqs are accepted; no dtype check is applied to `input` here.
48+ TORCH_CHECK (freqs.scalar_type () == at::ScalarType::Float,
49+ " Dtype of the freqs tensor must be float" );
50+
51+ return fwd_cuda (input, freqs, transpose_output);
52+ }
53+
// Backward entry point: validates `output_grads` and `freqs`, then delegates
// to bwd_cuda and returns its result unchanged.
//
// Parameters:
//   output_grads     - must be a 4D tensor.
//   freqs            - must be a 4D float tensor whose dims 1 and 2 are both 1.
//   transpose_output - forwarded untouched to bwd_cuda; its effect is not
//                      visible in this block.
//
// The checks mirror those in fwd, with output_grads in place of input. Each
// one uses TORCH_CHECK, which reports the given message when the condition is
// false.
54+ torch::Tensor bwd (const torch::Tensor &output_grads, const at::Tensor &freqs,
55+ const bool transpose_output) {
// Rank checks: both tensors must have exactly four dimensions.
56+ TORCH_CHECK (output_grads.dim () == 4 , " expected 4D tensor" );
57+ TORCH_CHECK (freqs.dim () == 4 , " expected 4D tensor" );
// Dim 0 must match; the message refers to it as the sequence length.
58+ TORCH_CHECK (
59+ output_grads.size (0 ) == freqs.size (0 ),
60+ " expected output_grads and freqs tensor have the same sequence length" );
// freqs must be singleton in dims 1 and 2.
61+ TORCH_CHECK (freqs.size (1 ) == 1 && freqs.size (2 ) == 1 ,
62+ " expected the second and third dims of the freqs tensor equal 1" );
// freqs may be narrower than output_grads along the last dim, but never wider.
63+ TORCH_CHECK (output_grads.size (3 ) >= freqs.size (3 ),
64+ " expected the last dim of the output_grads tensor equals or is "
65+ " greater than the freqs tensor" );
// Only float freqs are accepted; no dtype check is applied to output_grads here.
66+ TORCH_CHECK (freqs.scalar_type () == at::ScalarType::Float,
67+ " Dtype of the freqs tensor must be float" );
68+
69+ return bwd_cuda (output_grads, freqs, transpose_output);
70+ }
71+
72+ torch::Tensor fwd_cached (const at::Tensor &input, const at::Tensor &cos,
73+ const at::Tensor &sin, const bool transpose_output) {
3074 TORCH_CHECK (input.dim () == 4 , " expected 4D tensor" );
3175 TORCH_CHECK (cos.dim () == 4 , " expected 4D tensor" );
3276 TORCH_CHECK (sin.dim () == 4 , " expected 4D tensor" );
@@ -38,18 +82,20 @@ torch::Tensor fwd(const at::Tensor &input, const at::Tensor &cos,
3882 " expected the second and third dims of the cos tensor equal 1" );
3983 TORCH_CHECK (sin.size (1 ) == 1 && sin.size (2 ) == 1 ,
4084 " expected the second and third dims of the sin tensor equal 1" );
85+ TORCH_CHECK (cos.size (3 ) == sin.size (3 ),
86+ " expected cos and sin tensor have the same last dim" );
4187 TORCH_CHECK (input.size (3 ) >= cos.size (3 ),
42- " expected the last dim of the input tensor is greater than the "
43- " cos tensor" );
44- TORCH_CHECK (input.size (3 ) >= sin.size (3 ),
45- " expected the last dim of the input tensor is greater than the "
46- " sin tensor" );
88+ " expected the last dim of the input tensor equals or is "
89+ " greater than the cos tensor" );
90+ TORCH_CHECK (cos.scalar_type () == sin.scalar_type (),
91+ " expected cos and sin tensor have the same dtype" );
4792
48- return fwd_cuda (input, cos, sin, transpose_output);
93+ return fwd_cached_cuda (input, cos, sin, transpose_output);
4994}
5095
51- torch::Tensor bwd (const torch::Tensor &output_grads, const at::Tensor &cos,
52- const at::Tensor &sin, const bool transpose_output) {
96+ torch::Tensor bwd_cached (const torch::Tensor &output_grads,
97+ const at::Tensor &cos, const at::Tensor &sin,
98+ const bool transpose_output) {
5399 TORCH_CHECK (output_grads.dim () == 4 , " expected 4D tensor" );
54100 TORCH_CHECK (cos.dim () == 4 , " expected 4D tensor" );
55101 TORCH_CHECK (sin.dim () == 4 , " expected 4D tensor" );
@@ -63,16 +109,15 @@ torch::Tensor bwd(const torch::Tensor &output_grads, const at::Tensor &cos,
63109 " expected the second and third dims of the cos tensor equal 1" );
64110 TORCH_CHECK (sin.size (1 ) == 1 && sin.size (2 ) == 1 ,
65111 " expected the second and third dims of the sin tensor equal 1" );
66- TORCH_CHECK (
67- output_grads.size (3 ) >= cos.size (3 ),
68- " expected the last dim of the output_grads tensor is greater than the "
69- " cos tensor" );
70- TORCH_CHECK (
71- output_grads.size (3 ) >= sin.size (3 ),
72- " expected the last dim of the output_grads tensor is greater than the "
73- " sin tensor" );
112+ TORCH_CHECK (cos.size (3 ) == sin.size (3 ),
113+ " expected cos and sin tensor have the same last dim" );
114+ TORCH_CHECK (output_grads.size (3 ) >= cos.size (3 ),
115+ " expected the last dim of the output_grads tensor equals or is "
116+ " greater than the cos tensor" );
117+ TORCH_CHECK (cos.scalar_type () == sin.scalar_type (),
118+ " expected cos and sin tensor have the same dtype" );
74119
75- return bwd_cuda (output_grads, cos, sin, transpose_output);
120+ return bwd_cached_cuda (output_grads, cos, sin, transpose_output);
76121}
77122
78123} // end namespace fused_rope
@@ -82,4 +127,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
82127 " Fused Rotary Positional Embedding -- Forward." );
83128 m.def (" backward" , &fused_rope::bwd,
84129 " Fused Rotary Positional Embedding -- Backward." );
130+ m.def (" forward_cached" , &fused_rope::fwd_cached,
131+ " Fused Rotary Positional Embedding Cached -- Forward." );
132+ m.def (" backward_cached" , &fused_rope::bwd_cached,
133+ " Fused Rotary Positional Embedding Cached -- Backward." );
85134}
0 commit comments