Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ac8214e

Browse files
authored
Traceable RMSNorm (#1861)
1 parent a1df804 commit ac8214e

2 files changed

Lines changed: 144 additions & 1 deletion

File tree

apex/normalization/fused_layer_norm.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torch.nn.parameter import Parameter
66
from torch.nn import init
77
from torch.nn import functional as F
8+
from typing import List, Tuple
89

910
from apex._autocast_utils import _cast_if_autocast_enabled
1011

@@ -91,6 +92,125 @@ def backward(ctx, grad_output):
9192
return grad_input, grad_weight, None, None, None
9293

9394

95+
@torch.library.custom_op("apex::fused_rms_norm_affine_fwd", mutates_args=())
96+
def fused_rms_norm_affine_fwd(
97+
input: torch.Tensor,
98+
weight: torch.Tensor,
99+
normalized_shape: List[int],
100+
eps: float,
101+
memory_efficient: bool = False,
102+
) -> Tuple[torch.Tensor, torch.Tensor]:
103+
global fused_layer_norm_cuda
104+
if fused_layer_norm_cuda is None:
105+
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
106+
107+
input_ = input.contiguous()
108+
weight_ = weight.contiguous()
109+
output, invvar = fused_layer_norm_cuda.rms_forward_affine(
110+
input_, normalized_shape, weight_, eps
111+
)
112+
return output, invvar
113+
114+
115+
@fused_rms_norm_affine_fwd.register_fake
116+
def fused_rms_norm_affine_fwd_fake(
117+
input: torch.Tensor,
118+
weight: torch.Tensor,
119+
normalized_shape: List[int],
120+
eps: float,
121+
memory_efficient: bool = False,
122+
) -> Tuple[torch.Tensor, torch.Tensor]:
123+
input = input.contiguous()
124+
weight = weight.contiguous()
125+
idiff = input.ndim - len(normalized_shape)
126+
n = 1
127+
for i in range(idiff):
128+
n *= input.shape[i]
129+
if input.dtype in [torch.float16, torch.bfloat16]:
130+
dtype = torch.float32
131+
else:
132+
dtype = input.dtype
133+
return (
134+
torch.empty_like(input),
135+
torch.empty(
136+
[n],
137+
dtype=dtype,
138+
device=input.device,
139+
requires_grad=input.requires_grad,
140+
memory_format=torch.contiguous_format,
141+
),
142+
)
143+
144+
145+
@torch.library.custom_op("apex::fused_rms_norm_affine_bwd", mutates_args=())
146+
def fused_rms_norm_affine_bwd(
147+
grad_output: torch.Tensor,
148+
invvar: torch.Tensor,
149+
input_or_output: torch.Tensor,
150+
normalized_shape: List[int],
151+
weight: torch.Tensor,
152+
eps: float,
153+
memory_efficient: bool = False,
154+
) -> Tuple[torch.Tensor, torch.Tensor]:
155+
grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine(
156+
grad_output.contiguous(),
157+
invvar,
158+
input_or_output,
159+
normalized_shape,
160+
weight,
161+
eps,
162+
memory_efficient,
163+
)
164+
return grad_input, grad_weight
165+
166+
167+
@fused_rms_norm_affine_bwd.register_fake
168+
def fused_rms_norm_affine_bwd_fake(
169+
grad_output: torch.Tensor,
170+
invvar: torch.Tensor,
171+
input_or_output: torch.Tensor,
172+
normalized_shape: List[int],
173+
weight: torch.Tensor,
174+
eps: float,
175+
memory_efficient: bool = False,
176+
) -> Tuple[torch.Tensor, torch.Tensor]:
177+
grad_input = torch.empty_like(input_or_output)
178+
grad_weight = torch.empty_like(weight)
179+
return grad_input, grad_weight
180+
181+
182+
def backward(ctx, grad_output, grad_invvar):
183+
input_or_output, weight_, invvar = ctx.saved_tensors
184+
grad_input = grad_weight = None
185+
grad_input, grad_weight = fused_rms_norm_affine_bwd(
186+
grad_output,
187+
invvar,
188+
input_or_output,
189+
ctx.normalized_shape,
190+
weight_,
191+
ctx.eps,
192+
ctx.memory_efficient,
193+
)
194+
return grad_input, grad_weight, None, None, None
195+
196+
197+
def setup_context(ctx, inputs, output):
198+
input_, weight_, normalized_shape, eps, memory_efficient = inputs
199+
output_, invvar = output
200+
input_ = input_.contiguous()
201+
weight_ = weight_.contiguous()
202+
if memory_efficient:
203+
ctx.save_for_backward(output_, weight_, invvar)
204+
else:
205+
ctx.save_for_backward(input_, weight_, invvar)
206+
ctx.normalized_shape = normalized_shape
207+
ctx.eps = eps
208+
ctx.memory_efficient = memory_efficient
209+
210+
211+
fused_rms_norm_affine_fwd.register_autograd(backward, setup_context=setup_context)
212+
213+
94214
class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction):
95215

96216
@staticmethod
@@ -212,7 +332,7 @@ def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, e
212332
def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6, memory_efficient=False):
213333
args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps, memory_efficient)
214334
with torch.amp.autocast('cuda', enabled=False):
215-
return FusedRMSNormAffineFunction.apply(*args)
335+
return fused_rms_norm_affine_fwd(*args)[0]
216336

217337

218338
def fused_rms_norm(input, normalized_shape, eps=1e-6, memory_efficient=False):

tests/L0/run_fused_layer_norm/test_fused_layer_norm.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,29 @@ def test_layer_norm_export(self):
309309
native_x, fused_x = _prep_inputs(batch_size, normalized_shape, torch.float32)
310310
self._verify_export(fused, fused_x)
311311
self._verify_export(fused_m, fused_x)
312+
313+
def test_compile_fused_rms_norm(self):
314+
batch_size = 16
315+
normalized_shape = [32, 16]
316+
eager_mod = FusedRMSNorm(
317+
normalized_shape=normalized_shape, elementwise_affine=True
318+
).cuda()
319+
compiled_mod = torch.compile(fullgraph=True)(eager_mod)
320+
input_shape = [batch_size] + normalized_shape
321+
eager_x = torch.randn(input_shape, device="cuda").requires_grad_(True)
322+
compiled_x = eager_x.detach().clone().requires_grad_(True)
323+
324+
expected = eager_mod(eager_x)
325+
actual = compiled_mod(compiled_x)
326+
torch.testing.assert_close(actual, expected.detach())
327+
328+
g_eager = torch.rand_like(expected)
329+
with torch.no_grad():
330+
g_compiled = g_eager.detach().clone()
331+
expected.backward(g_eager)
332+
actual.backward(g_compiled)
333+
334+
torch.testing.assert_close(eager_x.grad, compiled_x.grad)
312335

313336
instantiate_device_type_tests(TestFusedLayerNorm, globals(), only_for=("cuda",))
314337
if __name__ == "__main__":

0 commit comments

Comments
 (0)