|
5 | 5 | from torch.nn.parameter import Parameter |
6 | 6 | from torch.nn import init |
7 | 7 | from torch.nn import functional as F |
| 8 | +from typing import List, Tuple |
8 | 9 |
|
9 | 10 | from apex._autocast_utils import _cast_if_autocast_enabled |
10 | 11 |
|
@@ -91,6 +92,125 @@ def backward(ctx, grad_output): |
91 | 92 | return grad_input, grad_weight, None, None, None |
92 | 93 |
|
93 | 94 |
|
| 95 | +@torch.library.custom_op("apex::fused_rms_norm_affine_fwd", mutates_args=()) |
| 96 | +def fused_rms_norm_affine_fwd( |
| 97 | + input: torch.Tensor, |
| 98 | + weight: torch.Tensor, |
| 99 | + normalized_shape: List[int], |
| 100 | + eps: float, |
| 101 | + memory_efficient: bool = False, |
| 102 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 103 | + global fused_layer_norm_cuda |
| 104 | + if fused_layer_norm_cuda is None: |
| 105 | + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") |
| 106 | + |
| 107 | + input_ = input.contiguous() |
| 108 | + weight_ = weight.contiguous() |
| 109 | + output, invvar = fused_layer_norm_cuda.rms_forward_affine( |
| 110 | + input_, normalized_shape, weight_, eps |
| 111 | + ) |
| 112 | + return output, invvar |
| 113 | + |
| 114 | + |
| 115 | +@fused_rms_norm_affine_fwd.register_fake |
| 116 | +def fused_rms_norm_affine_fwd_fake( |
| 117 | + input: torch.Tensor, |
| 118 | + weight: torch.Tensor, |
| 119 | + normalized_shape: List[int], |
| 120 | + eps: float, |
| 121 | + memory_efficient: bool = False, |
| 122 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 123 | + input = input.contiguous() |
| 124 | + weight = weight.contiguous() |
| 125 | + idiff = input.ndim - len(normalized_shape) |
| 126 | + n = 1 |
| 127 | + for i in range(idiff): |
| 128 | + n *= input.shape[i] |
| 129 | + if input.dtype in [torch.float16, torch.bfloat16]: |
| 130 | + dtype = torch.float32 |
| 131 | + else: |
| 132 | + dtype = input.dtype |
| 133 | + return ( |
| 134 | + torch.empty_like(input), |
| 135 | + torch.empty( |
| 136 | + [n], |
| 137 | + dtype=dtype, |
| 138 | + device=input.device, |
| 139 | + requires_grad=input.requires_grad, |
| 140 | + memory_format=torch.contiguous_format, |
| 141 | + ), |
| 142 | + ) |
| 143 | + |
| 144 | + |
| 145 | +@torch.library.custom_op("apex::fused_rms_norm_affine_bwd", mutates_args=()) |
| 146 | +def fused_rms_norm_affine_bwd( |
| 147 | + grad_output: torch.Tensor, |
| 148 | + invvar: torch.Tensor, |
| 149 | + input_or_output: torch.Tensor, |
| 150 | + normalized_shape: List[int], |
| 151 | + weight: torch.Tensor, |
| 152 | + eps: float, |
| 153 | + memory_efficient: bool = False, |
| 154 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 155 | + grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine( |
| 156 | + grad_output.contiguous(), |
| 157 | + invvar, |
| 158 | + input_or_output, |
| 159 | + normalized_shape, |
| 160 | + weight, |
| 161 | + eps, |
| 162 | + memory_efficient, |
| 163 | + ) |
| 164 | + return grad_input, grad_weight |
| 165 | + |
| 166 | + |
| 167 | +@fused_rms_norm_affine_bwd.register_fake |
| 168 | +def fused_rms_norm_affine_bwd_fake( |
| 169 | + grad_output: torch.Tensor, |
| 170 | + invvar: torch.Tensor, |
| 171 | + input_or_output: torch.Tensor, |
| 172 | + normalized_shape: List[int], |
| 173 | + weight: torch.Tensor, |
| 174 | + eps: float, |
| 175 | + memory_efficient: bool = False, |
| 176 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 177 | + grad_input = torch.empty_like(input_or_output) |
| 178 | + grad_weight = torch.empty_like(weight) |
| 179 | + return grad_input, grad_weight |
| 180 | + |
| 181 | + |
| 182 | +def backward(ctx, grad_output, grad_invvar): |
| 183 | + input_or_output, weight_, invvar = ctx.saved_tensors |
| 184 | + grad_input = grad_weight = None |
| 185 | + grad_input, grad_weight = fused_rms_norm_affine_bwd( |
| 186 | + grad_output, |
| 187 | + invvar, |
| 188 | + input_or_output, |
| 189 | + ctx.normalized_shape, |
| 190 | + weight_, |
| 191 | + ctx.eps, |
| 192 | + ctx.memory_efficient, |
| 193 | + ) |
| 194 | + return grad_input, grad_weight, None, None, None |
| 195 | + |
| 196 | + |
| 197 | +def setup_context(ctx, inputs, output): |
| 198 | + input_, weight_, normalized_shape, eps, memory_efficient = inputs |
| 199 | + output_, invvar = output |
| 200 | + input_ = input_.contiguous() |
| 201 | + weight_ = weight_.contiguous() |
| 202 | + if memory_efficient: |
| 203 | + ctx.save_for_backward(output_, weight_, invvar) |
| 204 | + else: |
| 205 | + ctx.save_for_backward(input_, weight_, invvar) |
| 206 | + ctx.normalized_shape = normalized_shape |
| 207 | + ctx.eps = eps |
| 208 | + ctx.memory_efficient = memory_efficient |
| 209 | + |
| 210 | + |
| 211 | +fused_rms_norm_affine_fwd.register_autograd(backward, setup_context=setup_context) |
| 212 | + |
| 213 | + |
94 | 214 | class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction): |
95 | 215 |
|
96 | 216 | @staticmethod |
@@ -212,7 +332,7 @@ def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, e |
212 | 332 | def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6, memory_efficient=False): |
213 | 333 | args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps, memory_efficient) |
214 | 334 | with torch.amp.autocast('cuda', enabled=False): |
215 | | - return FusedRMSNormAffineFunction.apply(*args) |
| 335 | + return fused_rms_norm_affine_fwd(*args)[0] |
216 | 336 |
|
217 | 337 |
|
218 | 338 | def fused_rms_norm(input, normalized_shape, eps=1e-6, memory_efficient=False): |
|
0 commit comments