Commit a786ca0

fix and generate docs for FusedRMSNorm (#1285)
1 parent 684c473 commit a786ca0

2 files changed

Lines changed: 9 additions & 6 deletions

apex/normalization/fused_layer_norm.py

Lines changed: 6 additions & 6 deletions
@@ -303,19 +303,19 @@ class FusedRMSNorm(torch.nn.Module):
     Currently only runs on cuda() tensors.
 
     .. math::
-        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+        y = \frac{x}{\mathrm{RMS}[x]} * \gamma
 
-    The mean and standard-deviation are calculated separately over the last
+    The root-mean-square is calculated separately over the last
     certain number dimensions which have to be of the shape specified by
     :attr:`normalized_shape`.
-    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
+    :math:`\gamma` is a learnable affine transform parameter of
     :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
 
     .. note::
         Unlike Batch Normalization and Instance Normalization, which applies
         scalar scale and bias for each entire channel/plane with the
-        :attr:`affine` option, Layer Normalization applies per-element scale and
-        bias with :attr:`elementwise_affine`.
+        :attr:`affine` option, RMS Normalization applies per-element scale
+        with :attr:`elementwise_affine`.
 
     This layer uses statistics computed from input data in both training and
     evaluation modes.
@@ -353,7 +353,7 @@ class FusedRMSNorm(torch.nn.Module):
         >>> # Activating the module
         >>> output = m(input)
 
-    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
+    .. _`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf
     """
 
     def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
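
The corrected docstring formula can be compared with the one it replaced using a small pure-Python sketch. This is not the fused CUDA kernel and not apex's API: the helper names `rms_norm` and `layer_norm` are hypothetical, each handles a single 1-D slice, and the sketch assumes `eps` is added to the mean square under the square root, which the docstring formula does not spell out.

```python
import math


def rms_norm(x, gamma=None, eps=1e-5):
    """Divide one vector by its root-mean-square, then scale per element.

    x:     list of floats (one slice along the normalized dimension)
    gamma: optional per-element scale, same length as x (defaults to ones)
    eps:   added to the mean square before the square root (assumed placement)
    """
    if gamma is None:
        gamma = [1.0] * len(x)
    mean_square = sum(v * v for v in x) / len(x)
    rms = math.sqrt(mean_square + eps)
    return [v / rms * g for v, g in zip(x, gamma)]


def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """The formula the old docstring described: center, scale, then shift."""
    n = len(x)
    if gamma is None:
        gamma = [1.0] * n
    if beta is None:
        beta = [0.0] * n
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    std = math.sqrt(var + eps)
    return [(v - mean) / std * g + b for v, g, b in zip(x, gamma, beta)]


x = [1.0, 2.0, 3.0, 4.0]
rms_out = rms_norm(x)   # no mean subtraction, no bias
ln_out = layer_norm(x)  # mean subtracted, bias available
```

With the default scale, the RMS-normalized values have a mean square close to 1 and keep the sign of the input, while the layer-normalized values sum to zero. That difference is why the `E[x]`, `Var[x]`, and `\beta` terms had to leave the `FusedRMSNorm` docstring.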

docs/source/layernorm.rst

Lines changed: 3 additions & 0 deletions
@@ -12,3 +12,6 @@ apex.normalization.fused_layer_norm
 
 .. autoclass:: FusedLayerNorm
     :members:
+
+.. autoclass:: FusedRMSNorm
+    :members:
