66namespace {
77void compute_n1_n2 (
88 at::Tensor input,
9- #ifdef VERSION_GE_1_1
109 at::IntArrayRef normalized_shape,
11- #else
12- at::IntList normalized_shape,
13- #endif
1410 int & n1,
1511 int & n2)
1612{
@@ -27,11 +23,7 @@ void compute_n1_n2(
2723}
2824
2925void check_args (
30- #ifdef VERSION_GE_1_1
3126 at::IntArrayRef normalized_shape,
32- #else
33- at::IntList normalized_shape,
34- #endif
3527 at::Tensor gamma,
3628 at::Tensor beta
3729 )
@@ -41,11 +33,7 @@ void check_args(
4133}
4234
4335void check_args (
44- #ifdef VERSION_GE_1_1
4536 at::IntArrayRef normalized_shape,
46- #else
47- at::IntList normalized_shape,
48- #endif
4937 at::Tensor gamma
5038 )
5139{
@@ -55,11 +43,7 @@ void check_args(
5543
5644void check_args (
5745 at::Tensor input,
58- #ifdef VERSION_GE_1_1
5946 at::IntArrayRef normalized_shape,
60- #else
61- at::IntList normalized_shape,
62- #endif
6347 int & n1,
6448 int & n2
6549 )
@@ -94,11 +78,7 @@ void check_args(
9478
9579void check_args (
9680 at::Tensor input,
97- #ifdef VERSION_GE_1_1
9881 at::IntArrayRef normalized_shape,
99- #else
100- at::IntList normalized_shape,
101- #endif
10282 at::Tensor gamma,
10383 at::Tensor beta,
10484 int & n1,
@@ -111,11 +91,7 @@ void check_args(
11191
11292void check_args (
11393 at::Tensor input,
114- #ifdef VERSION_GE_1_1
11594 at::IntArrayRef normalized_shape,
116- #else
117- at::IntList normalized_shape,
118- #endif
11995 at::Tensor gamma,
12096 int & n1,
12197 int & n2
@@ -133,11 +109,7 @@ void cuda_layer_norm(
133109 at::Tensor* input,
134110 int n1,
135111 int n2,
136- #ifdef VERSION_GE_1_1
137112 at::IntArrayRef normalized_shape,
138- #else
139- at::IntList normalized_shape,
140- #endif
141113 at::Tensor* gamma,
142114 at::Tensor* beta,
143115 double epsilon);
@@ -148,11 +120,7 @@ void cuda_layer_norm(
148120
149121std::vector<at::Tensor> layer_norm (
150122 at::Tensor input,
151- #ifdef VERSION_GE_1_1
152123 at::IntArrayRef normalized_shape,
153- #else
154- at::IntList normalized_shape,
155- #endif
156124 double epsilon) {
157125 CHECK_INPUT (input);
158126 int n1,n2;
@@ -167,11 +135,7 @@ std::vector<at::Tensor> layer_norm(
167135
168136std::vector<at::Tensor> layer_norm_affine (
169137 at::Tensor input,
170- #ifdef VERSION_GE_1_1
171138 at::IntArrayRef normalized_shape,
172- #else
173- at::IntList normalized_shape,
174- #endif
175139 at::Tensor gamma,
176140 at::Tensor beta,
177141 double epsilon) {
@@ -191,11 +155,7 @@ std::vector<at::Tensor> layer_norm_affine(
191155
192156std::vector<at::Tensor> layer_norm_affine_mixed_dtypes (
193157 at::Tensor input,
194- #ifdef VERSION_GE_1_1
195158 at::IntArrayRef normalized_shape,
196- #else
197- at::IntList normalized_shape,
198- #endif
199159 at::Tensor gamma,
200160 at::Tensor beta,
201161 double epsilon) {
@@ -217,11 +177,7 @@ void cuda_layer_norm_gradient(
217177 at::Tensor* input_or_output,
218178 int n1,
219179 int n2,
220- #ifdef VERSION_GE_1_1
221180 at::IntArrayRef normalized_shape,
222- #else
223- at::IntList normalized_shape,
224- #endif
225181 at::Tensor* gamma,
226182 at::Tensor* beta,
227183 double epsilon,
@@ -236,11 +192,7 @@ at::Tensor layer_norm_gradient(
236192 c10::optional<at::Tensor> mean_,
237193 at::Tensor invvar,
238194 at::Tensor input_or_output,
239- #ifdef VERSION_GE_1_1
240195 at::IntArrayRef normalized_shape,
241- #else
242- at::IntList normalized_shape,
243- #endif
244196 double epsilon,
245197 bool memory_efficient) {
246198 CHECK_INPUT (dout);
@@ -266,11 +218,7 @@ std::vector<at::Tensor> layer_norm_gradient_affine(
266218 c10::optional<at::Tensor> mean_,
267219 at::Tensor invvar,
268220 at::Tensor input_or_output,
269- #ifdef VERSION_GE_1_1
270221 at::IntArrayRef normalized_shape,
271- #else
272- at::IntList normalized_shape,
273- #endif
274222 at::Tensor gamma,
275223 at::Tensor beta,
276224 double epsilon,
@@ -304,11 +252,7 @@ void cuda_rms_norm(
304252 at::Tensor* input,
305253 int n1,
306254 int n2,
307- #ifdef VERSION_GE_1_1
308255 at::IntArrayRef normalized_shape,
309- #else
310- at::IntList normalized_shape,
311- #endif
312256 at::Tensor* gamma,
313257 double epsilon);
314258
@@ -318,11 +262,7 @@ void cuda_rms_norm(
318262
319263std::vector<at::Tensor> rms_norm (
320264 at::Tensor input,
321- #ifdef VERSION_GE_1_1
322265 at::IntArrayRef normalized_shape,
323- #else
324- at::IntList normalized_shape,
325- #endif
326266 double epsilon) {
327267 CHECK_INPUT (input);
328268 int n1,n2;
@@ -336,11 +276,7 @@ std::vector<at::Tensor> rms_norm(
336276
337277std::vector<at::Tensor> rms_norm_affine (
338278 at::Tensor input,
339- #ifdef VERSION_GE_1_1
340279 at::IntArrayRef normalized_shape,
341- #else
342- at::IntList normalized_shape,
343- #endif
344280 at::Tensor gamma,
345281 double epsilon) {
346282 CHECK_INPUT (input);
@@ -357,11 +293,7 @@ std::vector<at::Tensor> rms_norm_affine(
357293
358294std::vector<at::Tensor> rms_norm_affine_mixed_dtypes (
359295 at::Tensor input,
360- #ifdef VERSION_GE_1_1
361296 at::IntArrayRef normalized_shape,
362- #else
363- at::IntList normalized_shape,
364- #endif
365297 at::Tensor gamma,
366298 double epsilon) {
367299 CHECK_INPUT (input);
@@ -381,11 +313,7 @@ void cuda_rms_norm_gradient(
381313 at::Tensor* input_or_output,
382314 int n1,
383315 int n2,
384- #ifdef VERSION_GE_1_1
385316 at::IntArrayRef normalized_shape,
386- #else
387- at::IntList normalized_shape,
388- #endif
389317 at::Tensor* gamma,
390318 double epsilon,
391319 at::Tensor* grad_input,
@@ -396,11 +324,7 @@ at::Tensor rms_norm_gradient(
396324 at::Tensor dout,
397325 at::Tensor invvar,
398326 at::Tensor input_or_output,
399- #ifdef VERSION_GE_1_1
400327 at::IntArrayRef normalized_shape,
401- #else
402- at::IntList normalized_shape,
403- #endif
404328 double epsilon,
405329 bool memory_efficient) {
406330 CHECK_INPUT (dout);
@@ -419,11 +343,7 @@ std::vector<at::Tensor> rms_norm_gradient_affine(
419343 at::Tensor dout,
420344 at::Tensor invvar,
421345 at::Tensor input_or_output,
422- #ifdef VERSION_GE_1_1
423346 at::IntArrayRef normalized_shape,
424- #else
425- at::IntList normalized_shape,
426- #endif
427347 at::Tensor gamma,
428348 double epsilon,
429349 bool memory_efficient) {
0 commit comments