1+ import unittest
2+ import os
3+ import random
4+
5+ import torch
6+ import apex
7+
8+
9+ class TestFusedLayerNorm (unittest .TestCase ):
10+ def setUp (self ):
11+ self .module = apex .normalization .FusedLayerNorm (normalized_shape = [32 , 64 ], elementwise_affine = False )
12+ self .input_ = torch .randn (16 , 32 , 64 )
13+ torch .cuda .manual_seed (42 )
14+
15+ def forward_cpu (self , input_ ):
16+ self .module .cpu ()
17+ return self .module (input_ .cpu ())
18+
19+ def forward_cuda (self , input_ ):
20+ self .module .cuda ()
21+ return self .module (input_ .cuda ())
22+
23+ def test_forward_cuda (self ):
24+ out_ = self .forward_cuda (self .input_ )
25+ assert out_ .is_cuda == True
26+
27+ def test_forward_cpu (self ):
28+ out_ = self .forward_cpu (self .input_ )
29+ assert out_ .is_cuda == False
30+
31+ def test_same_output (self ):
32+ out_cpu = self .forward_cpu (self .input_ )
33+ out_cuda = self .forward_cuda (self .input_ )
34+ torch .testing .assert_allclose (out_cpu , out_cuda .cpu ())
35+
36+
37+ class TestFusedLayerNormElemWise (TestFusedLayerNorm ):
38+ def setUp (self ):
39+ self .module = apex .normalization .FusedLayerNorm (normalized_shape = [32 , 64 ], elementwise_affine = True )
40+ self .input_ = torch .randn (16 , 32 , 64 )
41+ torch .cuda .manual_seed (42 )
0 commit comments