1- import math
2- import torch
1+ import importlib
32import numbers
3+
4+ import torch
45from torch .nn .parameter import Parameter
56from torch .nn import init
67from torch .nn import functional as F
7- import importlib
8+
9+ from apex ._autocast_utils import _cast_if_autocast_enabled
810
# Handle to the compiled ``fused_layer_norm_cuda`` extension module. It starts out
# as None and is assigned the first time code in this module imports the extension.
# (A ``global`` statement is a no-op at module scope, so none is needed here.)
fused_layer_norm_cuda = None
1113
14+
class FusedLayerNormAffineFunction(torch.autograd.Function):
    """Autograd function for affine layer norm backed by ``fused_layer_norm_cuda``.

    The forward pass calls ``fused_layer_norm_cuda.forward_affine`` and the
    backward pass calls ``fused_layer_norm_cuda.backward_affine``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        """Call the extension's ``forward_affine`` kernel.

        Args:
            ctx: Autograd context; receives ``normalized_shape``, ``eps`` and the
                tensors saved for the backward pass.
            input: Tensor handed to the kernel (made contiguous first).
            weight: Tensor handed to the kernel (made contiguous first).
            bias: Tensor handed to the kernel (made contiguous first).
            normalized_shape: Forwarded to the kernel unchanged.
            eps: Forwarded to the kernel unchanged.

        Returns:
            The first of the three values returned by ``forward_affine``.
        """
        global fused_layer_norm_cuda
        if fused_layer_norm_cuda is None:
            # The extension is imported on first use and cached in the module global.
            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        # NOTE(review): ``mean`` / ``invvar`` are presumably the per-row statistics
        # computed by the kernel -- confirm against the extension source.
        output, mean, invvar = fused_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps
        )
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Call the extension's ``backward_affine`` kernel.

        Returns:
            A 5-tuple matching the arguments of ``forward``: gradients for
            ``input``, ``weight`` and ``bias``, then ``None`` for
            ``normalized_shape`` and ``eps``.
        """
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        # Relies on ``forward`` having already populated the module-level
        # ``fused_layer_norm_cuda`` handle.
        grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
            grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps
        )
        return grad_input, grad_weight, grad_bias, None, None
41+
class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction):
    """Variant of ``FusedLayerNormAffineFunction`` with a different forward kernel.

    Only ``forward`` is overridden: it calls
    ``fused_layer_norm_cuda.forward_affine_mixed_dtypes``. ``backward`` is
    inherited from the parent class.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        """Call ``forward_affine_mixed_dtypes`` and save tensors for backward.

        Returns the first of the three values produced by the kernel.
        """
        global fused_layer_norm_cuda
        if fused_layer_norm_cuda is None:
            # First use: import the extension and cache it in the module global.
            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

        ctx.normalized_shape = normalized_shape
        ctx.eps = eps

        contiguous_input = input.contiguous()
        contiguous_weight = weight.contiguous()
        contiguous_bias = bias.contiguous()

        kernel_result = fused_layer_norm_cuda.forward_affine_mixed_dtypes(
            contiguous_input,
            ctx.normalized_shape,
            contiguous_weight,
            contiguous_bias,
            ctx.eps,
        )
        output, mean, invvar = kernel_result

        ctx.save_for_backward(contiguous_input, contiguous_weight, contiguous_bias, mean, invvar)
        return output
1359
14- @staticmethod
15- def forward (ctx , input , weight , bias , normalized_shape , eps ):
16- global fused_layer_norm_cuda
17- if fused_layer_norm_cuda is None :
18- fused_layer_norm_cuda = importlib .import_module ("fused_layer_norm_cuda" )
19- ctx .normalized_shape = normalized_shape
20- ctx .eps = eps
21- input_ = input .contiguous ()
22- weight_ = weight .contiguous ()
23- bias_ = bias .contiguous ()
24- output , mean , invvar = fused_layer_norm_cuda .forward_affine (
25- input_ , ctx .normalized_shape , weight_ , bias_ , ctx .eps )
26- ctx .save_for_backward (input_ , weight_ , bias_ , mean , invvar )
27- return output
28-
29- @staticmethod
30- def backward (ctx , grad_output ):
31- input_ , weight_ , bias_ , mean , invvar = ctx .saved_tensors
32- grad_input = grad_weight = grad_bias = None
33- grad_input , grad_weight , grad_bias = fused_layer_norm_cuda .backward_affine (
34- grad_output .contiguous (), mean , invvar ,
35- input_ , ctx .normalized_shape ,
36- weight_ , bias_ , ctx .eps )
37- return grad_input , grad_weight , grad_bias , None , None
3860
class FusedLayerNormFunction(torch.autograd.Function):
    """Autograd function for layer norm without weight/bias, backed by ``fused_layer_norm_cuda``.

    The forward pass calls ``fused_layer_norm_cuda.forward`` and the backward
    pass calls ``fused_layer_norm_cuda.backward``.
    """

    @staticmethod
    def forward(ctx, input, normalized_shape, eps):
        """Call the extension's ``forward`` kernel.

        Args:
            ctx: Autograd context; receives ``normalized_shape``, ``eps`` and the
                tensors saved for the backward pass.
            input: Tensor handed to the kernel (made contiguous first).
            normalized_shape: Forwarded to the kernel unchanged.
            eps: Forwarded to the kernel unchanged.

        Returns:
            The first of the three values returned by the kernel.
        """
        global fused_layer_norm_cuda
        if fused_layer_norm_cuda is None:
            # The extension is imported on first use and cached in the module global.
            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward(input_, ctx.normalized_shape, ctx.eps)
        ctx.save_for_backward(input_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Call the extension's ``backward`` kernel.

        Returns:
            A 3-tuple matching the arguments of ``forward``: the gradient for
            ``input``, then ``None`` for ``normalized_shape`` and ``eps``.
        """
        input_, mean, invvar = ctx.saved_tensors
        # Relies on ``forward`` having already populated the module-level
        # ``fused_layer_norm_cuda`` handle.
        grad_input = fused_layer_norm_cuda.backward(
            grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, ctx.eps
        )
        return grad_input, None, None
82+
83+
def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6):
    """Apply ``FusedLayerNormAffineFunction`` with autocast disabled.

    The arguments are first passed through ``_cast_if_autocast_enabled``
    (presumably casting them to the active autocast dtype -- confirm against
    that helper); the autograd function then runs inside an
    ``autocast(enabled=False)`` region.
    """
    cast_args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps)
    autocast_off = torch.cuda.amp.autocast(enabled=False)
    with autocast_off:
        result = FusedLayerNormAffineFunction.apply(*cast_args)
    return result
4088
41- @staticmethod
42- def forward (ctx , input , normalized_shape , eps ):
43- global fused_layer_norm_cuda
44- if fused_layer_norm_cuda is None :
45- fused_layer_norm_cuda = importlib .import_module ("fused_layer_norm_cuda" )
46- ctx .normalized_shape = normalized_shape
47- ctx .eps = eps
48- input_ = input .contiguous ()
49- output , mean , invvar = fused_layer_norm_cuda .forward (
50- input_ , ctx .normalized_shape , ctx .eps )
51- ctx .save_for_backward (input_ , mean , invvar )
52- return output
53-
54- @staticmethod
55- def backward (ctx , grad_output ):
56- input_ , mean , invvar = ctx .saved_tensors
57- grad_input = None
58- grad_input = fused_layer_norm_cuda .backward (
59- grad_output .contiguous (), mean , invvar ,
60- input_ , ctx .normalized_shape ,
61- ctx .eps )
62- return grad_input , None , None
63-
64- def fused_layer_norm_affine (input , normalized_shape , weight , bias , eps = 1e-6 ):
65- return FusedLayerNormAffineFunction .apply (input , weight , bias , normalized_shape , eps )
6689
def fused_layer_norm(input, normalized_shape, eps=1e-6):
    """Apply ``FusedLayerNormFunction`` with autocast disabled.

    The arguments are first passed through ``_cast_if_autocast_enabled``
    (presumably casting them to the active autocast dtype -- confirm against
    that helper); the autograd function then runs inside an
    ``autocast(enabled=False)`` region.
    """
    cast_args = _cast_if_autocast_enabled(input, normalized_shape, eps)
    autocast_off = torch.cuda.amp.autocast(enabled=False)
    with autocast_off:
        result = FusedLayerNormFunction.apply(*cast_args)
    return result
94+
95+
def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6):
    """Apply ``FusedLayerNormAffineMixedDtypesFunction`` with autocast disabled.

    The arguments are first passed through ``_cast_if_autocast_enabled``
    (presumably casting them to the active autocast dtype -- confirm against
    that helper); the autograd function then runs inside an
    ``autocast(enabled=False)`` region.
    """
    cast_args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps)
    autocast_off = torch.cuda.amp.autocast(enabled=False)
    with autocast_off:
        result = FusedLayerNormAffineMixedDtypesFunction.apply(*cast_args)
    return result
100+
69101
70102class FusedLayerNorm (torch .nn .Module ):
71103 r"""Applies Layer Normalization over a mini-batch of inputs as described in
@@ -126,8 +158,9 @@ class FusedLayerNorm(torch.nn.Module):
126158
127159 .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
128160 """
161+
129162 def __init__ (self , normalized_shape , eps = 1e-5 , elementwise_affine = True ):
130- super (FusedLayerNorm , self ).__init__ ()
163+ super ().__init__ ()
131164
132165 global fused_layer_norm_cuda
133166 fused_layer_norm_cuda = importlib .import_module ("fused_layer_norm_cuda" )
@@ -141,8 +174,8 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
141174 self .weight = Parameter (torch .Tensor (* normalized_shape ))
142175 self .bias = Parameter (torch .Tensor (* normalized_shape ))
143176 else :
144- self .register_parameter (' weight' , None )
145- self .register_parameter (' bias' , None )
177+ self .register_parameter (" weight" , None )
178+ self .register_parameter (" bias" , None )
146179 self .reset_parameters ()
147180
148181 def reset_parameters (self ):
@@ -152,14 +185,34 @@ def reset_parameters(self):
152185
def forward(self, input):
    """Normalize ``input``, choosing the implementation by device and affine setting.

    CUDA tensors go through ``fused_layer_norm_affine`` when
    ``self.elementwise_affine`` is set and through ``fused_layer_norm``
    otherwise; any other tensor is handled by ``F.layer_norm``.
    """
    if input.is_cuda:
        if self.elementwise_affine:
            return fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps)
        return fused_layer_norm(input, self.normalized_shape, self.eps)
    # Non-CUDA input: fall back to the stock PyTorch implementation.
    return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
162193
def extra_repr(self):
    """Return the settings summary string for this module.

    The text is built from the ``normalized_shape``, ``eps`` and
    ``elementwise_affine`` entries of ``self.__dict__``.
    """
    # One literal instead of two adjacent ones joined by implicit concatenation.
    return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
196+
197+
198+ # NOTE (mkozuki): Why "mixed"?
199+ # MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype
200+ # as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype.
201+ # See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp"
class MixedFusedLayerNorm(FusedLayerNorm):
    """``FusedLayerNorm`` variant that always uses weight and bias.

    Per the note in this file, the output tensor takes the parameters' dtype
    here, whereas ``FusedLayerNorm`` uses the input tensor's dtype.
    """

    def __init__(self, normalized_shape, eps=1e-5, **kwargs):
        """Build the layer with ``elementwise_affine`` forced to ``True``.

        Passing ``elementwise_affine`` emits a warning; passing a falsy value
        for it raises ``RuntimeError``. Other keyword arguments are not used.
        """
        if "elementwise_affine" in kwargs:
            import warnings

            warnings.warn("MixedFusedLayerNorm does not support `elementwise_affine` argument")
            if not kwargs.pop("elementwise_affine"):
                raise RuntimeError("MixedFusedLayerNorm does not support `elementwise_affine = False`")

        super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True)

    def forward(self, input: torch.Tensor):
        """Normalize ``input`` with the mixed-dtype kernel on CUDA, else ``F.layer_norm``."""
        if input.is_cuda:
            return mixed_dtype_fused_layer_norm_affine(
                input, self.weight, self.bias, self.normalized_shape, self.eps
            )
        # NOTE (mkozuki): CPU path is here mainly for unittest sake.
        return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
0 commit comments