Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit ae75763

Browse files
crcrpar and yjk21 authored
FastLayerNorm compat with autocast (#1203)
* Persistent LayerNorm: Multi-CTA Rewrite * autocast support Co-authored-by: Young-Jun Ko <[email protected]>
1 parent 63d5dd6 commit ae75763

12 files changed

Lines changed: 2418 additions & 922 deletions

File tree

apex/contrib/csrc/layer_norm/ln.h

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
#pragma once
2+
3+
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <utility>

#include <cuda_fp16.h>
#include <cuda_bf16.h>
6+
7+
namespace layer_norm {
8+
9+
////////////////////////////////////////////////////////////////////////////////////////////////////
10+
11+
template<typename Params>
12+
struct LaunchParams{
13+
14+
size_t workspace_bytes;
15+
size_t barrier_size;
16+
17+
cudaDeviceProp * props;
18+
19+
cudaStream_t stream;
20+
21+
Params params;
22+
23+
};
24+
25+
////////////////////////////////////////////////////////////////////////////////////////////////////
26+
27+
// Parameters common to the forward and backward kernels.
// A default-constructed instance has every count set to 0 and every pointer null.
struct ParamsBase {
    // Kept user-provided (rather than "= default") on purpose: the members are
    // zeroed by the default initializers below, and the type remains a
    // non-aggregate with a user-provided default constructor, as before.
    ParamsBase() {}

    // Multi-CTA: number of distinct CTA groups. Otherwise equal to gridDim.x.
    int ctas_per_col = 0;

    // The input is viewed as a (rows x cols) matrix; normalization runs across columns.
    int rows = 0;
    int cols = 0;

    // Data pointers used by both passes.
    void *x = nullptr;
    void *mu = nullptr;
    void *rs = nullptr;
    void *gamma = nullptr;

    // Multi-CTA scratch space in global memory.
    void *workspace = nullptr;

    // Multi-CTA synchronization barriers in global memory.
    int *barrier = nullptr;
};
61+
62+
////////////////////////////////////////////////////////////////////////////////////////////////////
63+
64+
// Parameters of the layer-norm forward pass: everything in ParamsBase plus the
// forward-only members below. Default construction yields null pointers and
// epsilon == 0.
struct FwdParams : public ParamsBase {
    // User-provided so the type keeps a user-provided default constructor; the
    // base-class constructor runs implicitly and the members below are set by
    // their default initializers.
    FwdParams() {}

    // Output of LN FWD.
    void *z = nullptr;
    // NOTE(review): presumably the additive LN weight and the epsilon used when
    // normalizing; their use is not visible in this header - confirm in the kernels.
    void *beta = nullptr;
    float epsilon = 0.f;
};
79+
80+
////////////////////////////////////////////////////////////////////////////////////////////////////
81+
82+
// Parameters of the layer-norm backward pass: everything in ParamsBase plus the
// gradient pointers below. Default construction leaves every pointer null.
struct BwdParams : public ParamsBase {
    // User-provided so the type keeps a user-provided default constructor; the
    // base-class constructor runs implicitly and the members below are set by
    // their default initializers.
    BwdParams() {}

    // Input: gradient with respect to the LN FWD output.
    void *dz = nullptr;

    // Scratch buffers for the Wgrad pre-reduction.
    void *dbeta_part = nullptr;
    void *dgamma_part = nullptr;

    // Output: Dgrad.
    void *dx = nullptr;

    // Output: Wgrad.
    void *dbeta = nullptr;
    void *dgamma = nullptr;
};
108+
109+
////////////////////////////////////////////////////////////////////////////////////////////////////
110+
111+
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
112+
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;
113+
using FunctionKey = uint64_t;
114+
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
115+
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
116+
117+
extern FwdRegistry FWD_FUNCS;
118+
extern BwdRegistry BWD_FUNCS;
119+
120+
////////////////////////////////////////////////////////////////////////////////////////////////////
121+
122+
using fp32 = float;
123+
using fp16 = half;
124+
using bf16 = nv_bfloat16;
125+
126+
////////////////////////////////////////////////////////////////////////////////////////////////////
127+
128+
template<typename T>
129+
struct TypeId{};
130+
131+
template<>
132+
struct TypeId<fp16>{
133+
constexpr static uint32_t Value = 0;
134+
};
135+
136+
template<>
137+
struct TypeId<bf16>{
138+
constexpr static uint32_t Value = 1;
139+
};
140+
141+
template<>
142+
struct TypeId<fp32>{
143+
constexpr static uint32_t Value = 2;
144+
};
145+
146+
////////////////////////////////////////////////////////////////////////////////////////////////////
147+
148+
template<typename T, int S>
149+
struct Type2Key{
150+
constexpr static uint32_t Value = TypeId<T>::Value << S;
151+
};
152+
153+
////////////////////////////////////////////////////////////////////////////////////////////////////
154+
155+
template<typename T>
156+
struct WeightType2Key : public Type2Key<T, 0>{};
157+
158+
template<typename T>
159+
struct InputType2Key : public Type2Key<T, 2>{};
160+
161+
template<typename T>
162+
struct OutputType2Key : public Type2Key<T, 4>{};
163+
164+
template<typename T>
165+
struct ComputeType2Key : public Type2Key<T, 6>{};
166+
167+
////////////////////////////////////////////////////////////////////////////////////////////////////
168+
169+
template<typename W, typename I, typename O, typename C>
170+
struct Types2Key{
171+
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
172+
constexpr static inline uint64_t get(const uint64_t hidden_size){
173+
constexpr uint64_t type_key = Value;
174+
return (type_key << 32) | hidden_size;
175+
}
176+
};
177+
178+
////////////////////////////////////////////////////////////////////////////////////////////////////
179+
180+
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
181+
struct FwdRegistrar{
182+
FwdRegistrar(FwdFunction f){
183+
uint64_t key = Types2Key<W,I,O,C>::get(HIDDEN_SIZE);
184+
FWD_FUNCS.insert({ key, f });
185+
}
186+
};
187+
188+
////////////////////////////////////////////////////////////////////////////////////////////////////
189+
190+
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
191+
struct BwdRegistrar{
192+
BwdRegistrar(BwdFunction f){
193+
uint64_t key = Types2Key<W,I,O,C>::get(HIDDEN_SIZE);
194+
BWD_FUNCS.insert({ key, f });
195+
}
196+
};
197+
198+
////////////////////////////////////////////////////////////////////////////////////////////////////
199+
200+
} // namespace layer_norm

0 commit comments

Comments
 (0)