Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit e2083df

Browse files
authored
fast layer norm (#1037)
1 parent a78ccf0 commit e2083df

10 files changed

Lines changed: 1091 additions & 1 deletion

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ apex.egg-info
22
dist
33
build
44
docs/build
5-
*~
5+
*~
6+
__pycache__
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#include <torch/extension.h>
2+
#include "ATen/cuda/CUDAContext.h"
3+
4+
// Forward declaration of the forward-pass launcher; no definition is visible
// in this file. ln_fwd() calls it with the output tensors it allocates
// (y, mu, rsigma), the inputs (x, gamma, beta, epsilon), the two extents of
// the 2-D input x (rows, cols) and the current CUDA stream.
void ln_fwd_cuda(at::Tensor &y, at::Tensor &mu, at::Tensor &rsigma,
5+
const at::Tensor &x, const at::Tensor &gamma,
6+
const at::Tensor &beta, const float epsilon, const int rows, const int cols,
7+
cudaStream_t stream);
8+
9+
// Forward declaration of the backward-pass launcher; no definition is visible
// in this file. ln_bwd() calls it with the output tensors it allocates
// (dx, dgamma, dbeta), the incoming gradient dw, the saved forward inputs and
// statistics (x, mu, rsigma, gamma), the two extents of the 2-D input x
// (rows, cols) and the current CUDA stream.
void ln_bwd_cuda(at::Tensor &dx, at::Tensor &dgamma, at::Tensor &dbeta,
10+
const at::Tensor &dw, const at::Tensor &x,
11+
const at::Tensor &mu, const at::Tensor &rsigma,
12+
const at::Tensor &gamma, const int rows, const int cols, cudaStream_t stream);
13+
14+
15+
/**
 * Host-side entry point for the LayerNorm forward pass.
 *
 * Validates the arguments, allocates the outputs and forwards everything to
 * ln_fwd_cuda() on the current CUDA stream.
 *
 * @param x        Contiguous 2-D CUDA tensor of shape (rows, cols). Callers
 *                 holding a B x S x hidden_size tensor must flatten it first,
 *                 because anything other than 2-D is rejected.
 * @param gamma    CUDA tensor with `cols` elements and the same dtype as x.
 * @param beta     CUDA tensor with the same shape as gamma and the same dtype
 *                 as x.
 * @param epsilon  Must be non-negative.
 * @return {y, mu, rsigma}: y has the shape/dtype of x; mu and rsigma are
 *         float32 tensors with `rows` elements each.
 * @throws c10::Error (via TORCH_CHECK) if any precondition above is violated
 *         or if an extent of x does not fit in a 32-bit int.
 */
std::vector<at::Tensor> ln_fwd(const at::Tensor &x,      // rows x hidden_size (2-D)
                               const at::Tensor &gamma,  // hidden_size
                               const at::Tensor &beta,   // hidden_size
                               const float epsilon) {
  TORCH_CHECK(x.is_cuda(), "ln_fwd: x must be a CUDA tensor");
  TORCH_CHECK(gamma.is_cuda(), "ln_fwd: gamma must be a CUDA tensor");
  TORCH_CHECK(beta.is_cuda(), "ln_fwd: beta must be a CUDA tensor");

  TORCH_CHECK(x.is_contiguous(), "ln_fwd: x must be contiguous");
  const auto sizes = x.sizes();
  TORCH_CHECK(sizes.size() == 2, "ln_fwd: x must be 2-D (rows x cols)");

  // The launcher takes `int` extents while tensor sizes are 64-bit; convert
  // explicitly and verify the round trip so oversized inputs fail loudly
  // instead of being silently truncated.
  const int rows = static_cast<int>(sizes[0]);
  const int cols = static_cast<int>(sizes[1]);
  TORCH_CHECK(rows == sizes[0] && cols == sizes[1],
              "ln_fwd: dimensions of x do not fit in a 32-bit int");

  const auto dtype = x.scalar_type();

  TORCH_CHECK(gamma.dtype() == dtype, "ln_fwd: gamma must have the same dtype as x");
  TORCH_CHECK(beta.dtype() == dtype, "ln_fwd: beta must have the same dtype as x");

  TORCH_CHECK(gamma.sizes() == beta.sizes(), "ln_fwd: gamma and beta must have the same shape");
  TORCH_CHECK(gamma.numel() == cols, "ln_fwd: gamma must have as many elements as x has columns");

  TORCH_CHECK(epsilon >= 0.f, "ln_fwd: epsilon must be non-negative");

  auto stream = at::cuda::getCurrentCUDAStream().stream();

  auto y = torch::empty_like(x);

  // Per-row statistics are always kept in float32, whatever the dtype of x;
  // ln_bwd() insists on float32 for both.
  const auto stats_opts = x.options().dtype(torch::kFloat32);
  auto mu = torch::empty({rows}, stats_opts);
  auto rsigma = torch::empty({rows}, stats_opts);

  ln_fwd_cuda(y, mu, rsigma, x, gamma, beta, epsilon, rows, cols, stream);

  return {y, mu, rsigma};
}
55+
56+
57+
58+
/**
 * Host-side entry point for the LayerNorm backward pass.
 *
 * Validates the arguments, allocates the gradient outputs and forwards
 * everything to ln_bwd_cuda() on the current CUDA stream.
 *
 * @param dw      Contiguous CUDA tensor with the same shape and dtype as x
 *                (presumably the gradient w.r.t. the forward output; confirm
 *                against the kernel).
 * @param x       Contiguous 2-D CUDA tensor of shape (rows, cols).
 * @param mu      float32 CUDA tensor with `rows` elements.
 * @param rsigma  float32 CUDA tensor with the same shape as mu.
 * @param gamma   CUDA tensor with `cols` elements and the same dtype as x.
 * @return {dx, dgamma, dbeta}: dx has the shape/dtype of x; dgamma and dbeta
 *         have the shape/dtype of gamma.
 * @throws c10::Error (via TORCH_CHECK) if any precondition above is violated
 *         or if an extent of x does not fit in a 32-bit int.
 */
std::vector<at::Tensor> ln_bwd(const at::Tensor &dw,      // rows x hidden_size (2-D)
                               const at::Tensor &x,       // rows x hidden_size (2-D)
                               const at::Tensor &mu,      // rows, FP32!
                               const at::Tensor &rsigma,  // rows, FP32!
                               const at::Tensor &gamma    // hidden_size
) {
  TORCH_CHECK(x.is_cuda(), "ln_bwd: x must be a CUDA tensor");
  TORCH_CHECK(dw.is_cuda(), "ln_bwd: dw must be a CUDA tensor");
  TORCH_CHECK(mu.is_cuda(), "ln_bwd: mu must be a CUDA tensor");
  TORCH_CHECK(rsigma.is_cuda(), "ln_bwd: rsigma must be a CUDA tensor");
  TORCH_CHECK(gamma.is_cuda(), "ln_bwd: gamma must be a CUDA tensor");

  TORCH_CHECK(x.is_contiguous(), "ln_bwd: x must be contiguous");
  TORCH_CHECK(dw.is_contiguous(), "ln_bwd: dw must be contiguous");

  const auto sizes = x.sizes();
  TORCH_CHECK(sizes.size() == 2, "ln_bwd: x must be 2-D (rows x cols)");
  TORCH_CHECK(dw.sizes() == sizes, "ln_bwd: dw must have the same shape as x");

  // The launcher takes `int` extents while tensor sizes are 64-bit; convert
  // explicitly and verify the round trip so oversized inputs fail loudly
  // instead of being silently truncated at the call below.
  const int rows = static_cast<int>(sizes[0]);
  const int cols = static_cast<int>(sizes[1]);
  TORCH_CHECK(rows == sizes[0] && cols == sizes[1],
              "ln_bwd: dimensions of x do not fit in a 32-bit int");

  const auto dtype = x.scalar_type();
  TORCH_CHECK(dw.dtype() == dtype, "ln_bwd: dw must have the same dtype as x");
  TORCH_CHECK(gamma.dtype() == dtype, "ln_bwd: gamma must have the same dtype as x");
  TORCH_CHECK(mu.dtype() == torch::kFloat32, "ln_bwd: mu must be float32");
  TORCH_CHECK(rsigma.dtype() == torch::kFloat32, "ln_bwd: rsigma must be float32");
  TORCH_CHECK(mu.sizes() == rsigma.sizes(), "ln_bwd: mu and rsigma must have the same shape");
  TORCH_CHECK(mu.numel() == rows, "ln_bwd: mu must have one element per row of x");

  TORCH_CHECK(gamma.numel() == cols, "ln_bwd: gamma must have as many elements as x has columns");

  auto stream = at::cuda::getCurrentCUDAStream().stream();

  auto dx = torch::empty_like(x);
  // Both parameter gradients take their shape and dtype from gamma.
  auto dgamma = torch::empty_like(gamma);
  auto dbeta = torch::empty_like(gamma);

  ln_bwd_cuda(dx, dgamma, dbeta, dw, x, mu, rsigma, gamma, rows, cols, stream);

  return {dx, dgamma, dbeta};
}
101+
102+
// Python bindings: exposes the two host entry points of this file under the
// extension module name chosen at build time (TORCH_EXTENSION_NAME).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Module-level docstring shown by help() on the Python side.
  m.doc() = "CUDA LayerNorm";

  // def() returns the module itself, so both registrations can be chained.
  m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel")
      .def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel");
}

0 commit comments

Comments
 (0)