|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <unordered_map> |
| 4 | +#include <cuda_fp16.h> |
| 5 | +#include <cuda_bf16.h> |
| 6 | + |
| 7 | +namespace layer_norm { |
| 8 | + |
| 9 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 10 | + |
| 11 | +template<typename Params> |
| 12 | +struct LaunchParams{ |
| 13 | + |
| 14 | + size_t workspace_bytes; |
| 15 | + size_t barrier_size; |
| 16 | + |
| 17 | + cudaDeviceProp * props; |
| 18 | + |
| 19 | + cudaStream_t stream; |
| 20 | + |
| 21 | + Params params; |
| 22 | + |
| 23 | +}; |
| 24 | + |
| 25 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 26 | + |
| 27 | +struct ParamsBase { |
| 28 | + ParamsBase() |
| 29 | + : ctas_per_col(0) |
| 30 | + , rows(0) |
| 31 | + , cols(0) |
| 32 | + , x(nullptr) |
| 33 | + , mu(nullptr) |
| 34 | + , rs(nullptr) |
| 35 | + , gamma(nullptr) |
| 36 | + , workspace(nullptr) |
| 37 | + , barrier(nullptr) |
| 38 | + { |
| 39 | + } |
| 40 | + |
| 41 | + // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. |
| 42 | + int ctas_per_col; |
| 43 | + |
| 44 | + // Input is interpreted as matrix. We normalize across columns. |
| 45 | + int rows; |
| 46 | + int cols; |
| 47 | + |
| 48 | + // Common data pointers. |
| 49 | + void *x; |
| 50 | + void *mu; |
| 51 | + void *rs; |
| 52 | + void *gamma; |
| 53 | + |
| 54 | + // Multi-CTA workspace in gmem. |
| 55 | + void *workspace; |
| 56 | + |
| 57 | + // Multi-CTA sync barriers in gmem. |
| 58 | + int *barrier; |
| 59 | + |
| 60 | +}; |
| 61 | + |
| 62 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 63 | + |
| 64 | +struct FwdParams : public ParamsBase { |
| 65 | + FwdParams() |
| 66 | + : ParamsBase() |
| 67 | + , z(nullptr) |
| 68 | + , beta(nullptr) |
| 69 | + , epsilon(0.f) |
| 70 | + { |
| 71 | + } |
| 72 | + |
| 73 | + // Output of LN FWD. |
| 74 | + void *z; |
| 75 | + void *beta; |
| 76 | + float epsilon; |
| 77 | + |
| 78 | +}; |
| 79 | + |
| 80 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 81 | + |
| 82 | +struct BwdParams : public ParamsBase { |
| 83 | + BwdParams() |
| 84 | + : ParamsBase() |
| 85 | + , dz(nullptr) |
| 86 | + , dbeta_part(nullptr) |
| 87 | + , dgamma_part(nullptr) |
| 88 | + , dx(nullptr) |
| 89 | + , dbeta(nullptr) |
| 90 | + , dgamma(nullptr) |
| 91 | + { |
| 92 | + } |
| 93 | + |
| 94 | + // Input: gradient wrt. LN FWD output. |
| 95 | + void *dz; |
| 96 | + |
| 97 | + // Workspace for Wgrad pre-reduction. |
| 98 | + void *dbeta_part; |
| 99 | + void *dgamma_part; |
| 100 | + |
| 101 | + // Output: Dgrad. |
| 102 | + void *dx; |
| 103 | + // Output: Wgrad. |
| 104 | + void *dbeta; |
| 105 | + void *dgamma; |
| 106 | + |
| 107 | +}; |
| 108 | + |
| 109 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 110 | + |
| 111 | +using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>; |
| 112 | +using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>; |
| 113 | +using FunctionKey = uint64_t; |
| 114 | +using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>; |
| 115 | +using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>; |
| 116 | + |
| 117 | +extern FwdRegistry FWD_FUNCS; |
| 118 | +extern BwdRegistry BWD_FUNCS; |
| 119 | + |
| 120 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 121 | + |
| 122 | +using fp32 = float; |
| 123 | +using fp16 = half; |
| 124 | +using bf16 = nv_bfloat16; |
| 125 | + |
| 126 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 127 | + |
| 128 | +template<typename T> |
| 129 | +struct TypeId{}; |
| 130 | + |
| 131 | +template<> |
| 132 | +struct TypeId<fp16>{ |
| 133 | + constexpr static uint32_t Value = 0; |
| 134 | +}; |
| 135 | + |
| 136 | +template<> |
| 137 | +struct TypeId<bf16>{ |
| 138 | + constexpr static uint32_t Value = 1; |
| 139 | +}; |
| 140 | + |
| 141 | +template<> |
| 142 | +struct TypeId<fp32>{ |
| 143 | + constexpr static uint32_t Value = 2; |
| 144 | +}; |
| 145 | + |
| 146 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 147 | + |
| 148 | +template<typename T, int S> |
| 149 | +struct Type2Key{ |
| 150 | + constexpr static uint32_t Value = TypeId<T>::Value << S; |
| 151 | +}; |
| 152 | + |
| 153 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 154 | + |
| 155 | +template<typename T> |
| 156 | +struct WeightType2Key : public Type2Key<T, 0>{}; |
| 157 | + |
| 158 | +template<typename T> |
| 159 | +struct InputType2Key : public Type2Key<T, 2>{}; |
| 160 | + |
| 161 | +template<typename T> |
| 162 | +struct OutputType2Key : public Type2Key<T, 4>{}; |
| 163 | + |
| 164 | +template<typename T> |
| 165 | +struct ComputeType2Key : public Type2Key<T, 6>{}; |
| 166 | + |
| 167 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 168 | + |
| 169 | +template<typename W, typename I, typename O, typename C> |
| 170 | +struct Types2Key{ |
| 171 | + constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value; |
| 172 | + constexpr static inline uint64_t get(const uint64_t hidden_size){ |
| 173 | + constexpr uint64_t type_key = Value; |
| 174 | + return (type_key << 32) | hidden_size; |
| 175 | + } |
| 176 | +}; |
| 177 | + |
| 178 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 179 | + |
| 180 | +template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE> |
| 181 | +struct FwdRegistrar{ |
| 182 | + FwdRegistrar(FwdFunction f){ |
| 183 | + uint64_t key = Types2Key<W,I,O,C>::get(HIDDEN_SIZE); |
| 184 | + FWD_FUNCS.insert({ key, f }); |
| 185 | + } |
| 186 | +}; |
| 187 | + |
| 188 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 189 | + |
| 190 | +template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE> |
| 191 | +struct BwdRegistrar{ |
| 192 | + BwdRegistrar(BwdFunction f){ |
| 193 | + uint64_t key = Types2Key<W,I,O,C>::get(HIDDEN_SIZE); |
| 194 | + BWD_FUNCS.insert({ key, f }); |
| 195 | + } |
| 196 | +}; |
| 197 | + |
| 198 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 199 | + |
| 200 | +} // namespace layer_norm |
0 commit comments