Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit f290f3c

Browse files
committed
Ignore updateGradInput if self.gradInput is nil
Change suggested by Natalia Gimelshein to align behaviour with other modules (linear, spatial convolution, etc.).

This allows model definitions such as:

    local lenet = nn.Sequential()
    lenet:add(nn.MulConstant(0.00390625))
    lenet:add(nn.SpatialConvolution(1,20,5,5,1,1,0)) -- 1*28*28 -> 20*24*24
    lenet:add(nn.SpatialMaxPooling(2, 2, 2, 2)) -- 20*24*24 -> 20*12*12
    lenet:add(nn.SpatialConvolution(20,50,5,5,1,1,0)) -- 20*12*12 -> 50*8*8
    lenet:add(nn.SpatialMaxPooling(2,2,2,2)) -- 50*8*8 -> 50*4*4
    lenet:add(nn.View(-1):setNumInputDims(3)) -- 50*4*4 -> 800
    lenet:add(nn.Linear(800,500)) -- 800 -> 500
    lenet:add(nn.ReLU())
    lenet:add(nn.Linear(500, 10)) -- 500 -> 10
    lenet:add(nn.LogSoftMax())
    lenet:get(1).gradInput = nil
    lenet:get(2).gradInput = nil

Setting gradInput to nil on the first two layers removes unnecessary dgrad computations and saves about 5% of compute utilization.
1 parent 42ef6c4 commit f290f3c

1 file changed

Lines changed: 14 additions & 12 deletions

File tree

MulConstant.lua

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ function MulConstant:__init(constant_scalar,ip)
44
parent.__init(self)
55
assert(type(constant_scalar) == 'number', 'input is not scalar!')
66
self.constant_scalar = constant_scalar
7-
7+
88
-- default for inplace is false
99
self.inplace = ip or false
1010
if (ip and type(ip) ~= 'boolean') then
@@ -22,18 +22,20 @@ function MulConstant:updateOutput(input)
2222
self.output:mul(self.constant_scalar)
2323
end
2424
return self.output
25-
end
25+
end
2626

2727
function MulConstant:updateGradInput(input, gradOutput)
28-
if self.inplace then
29-
gradOutput:mul(self.constant_scalar)
30-
self.gradInput = gradOutput
31-
-- restore previous input value
32-
input:div(self.constant_scalar)
33-
else
34-
self.gradInput:resizeAs(gradOutput)
35-
self.gradInput:copy(gradOutput)
36-
self.gradInput:mul(self.constant_scalar)
28+
if self.gradInput then
29+
if self.inplace then
30+
gradOutput:mul(self.constant_scalar)
31+
self.gradInput = gradOutput
32+
-- restore previous input value
33+
input:div(self.constant_scalar)
34+
else
35+
self.gradInput:resizeAs(gradOutput)
36+
self.gradInput:copy(gradOutput)
37+
self.gradInput:mul(self.constant_scalar)
38+
end
39+
return self.gradInput
3740
end
38-
return self.gradInput
3941
end

0 commit comments

Comments
 (0)