Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d130ec1

Browse files
committed
quick fix: make FusedLayerNorm compatible with cpu
1 parent 683b6e0 commit d130ec1

2 files changed

Lines changed: 45 additions & 0 deletions

File tree

apex/normalization/fused_layer_norm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numbers
44
from torch.nn.parameter import Parameter
55
from torch.nn import init
6+
from torch.nn import functional as F
67
import importlib
78

89
class FusedLayerNormAffineFunction(torch.autograd.Function):
@@ -144,6 +145,9 @@ def reset_parameters(self):
144145
init.zeros_(self.bias)
145146

146147
def forward(self, input):
148+
if not input.is_cuda:
149+
return F.layer_norm(
150+
input, self.normalized_shape, self.weight, self.bias, self.eps)
147151
if self.elementwise_affine:
148152
return FusedLayerNormAffineFunction(self.normalized_shape,self.eps)(
149153
input, self.weight, self.bias)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import unittest
2+
import os
3+
import random
4+
5+
import torch
6+
import apex
7+
8+
9+
class TestFusedLayerNorm(unittest.TestCase):
10+
def setUp(self):
11+
self.module = apex.normalization.FusedLayerNorm(normalized_shape=[32, 64], elementwise_affine=False)
12+
self.input_ = torch.randn(16, 32, 64)
13+
torch.cuda.manual_seed(42)
14+
15+
def forward_cpu(self, input_):
16+
self.module.cpu()
17+
return self.module(input_.cpu())
18+
19+
def forward_cuda(self, input_):
20+
self.module.cuda()
21+
return self.module(input_.cuda())
22+
23+
def test_forward_cuda(self):
24+
out_ = self.forward_cuda(self.input_)
25+
assert out_.is_cuda == True
26+
27+
def test_forward_cpu(self):
28+
out_ = self.forward_cpu(self.input_)
29+
assert out_.is_cuda == False
30+
31+
def test_same_output(self):
32+
out_cpu = self.forward_cpu(self.input_)
33+
out_cuda = self.forward_cuda(self.input_)
34+
torch.testing.assert_allclose(out_cpu, out_cuda.cpu())
35+
36+
37+
class TestFusedLayerNormElemWise(TestFusedLayerNorm):
38+
def setUp(self):
39+
self.module = apex.normalization.FusedLayerNorm(normalized_shape=[32, 64], elementwise_affine=True)
40+
self.input_ = torch.randn(16, 32, 64)
41+
torch.cuda.manual_seed(42)

0 commit comments

Comments
 (0)