
Commit d5b1e06

Fix for non-contiguous grad_output in cuDNN conv
1 parent 6c77fa9 commit d5b1e06
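
Background for the fix: grad_output reaches the cuDNN convolution backward as whatever tensor was passed to backward(), and that tensor can be non-contiguous (for example, a slice of a larger tensor). The cuDNN path assumes densely packed memory, so the backward pass now makes the gradient contiguous before using it. A minimal standalone sketch (not part of this commit) of how such a tensor arises:

    import torch

    # Slicing along a non-leading dimension keeps the original strides,
    # so the result is no longer densely packed in memory.
    grad = torch.randn(2, 2, 5, 10, 10)[:, 1]
    print(grad.is_contiguous())               # False

    # .contiguous() copies the data into dense, row-major storage.
    print(grad.contiguous().is_contiguous())  # True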

File tree

  test/test_nn.py
  torch/nn/_functions/conv.py

2 files changed: +17 −0 lines changed


test/test_nn.py

Lines changed: 16 additions & 0 deletions
@@ -1274,6 +1274,22 @@ def test_inplace_thnn(self):
             output.backward(grad_output)
             self.assertEqual(grad_output, grad_output_clone)

+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    def test_noncontig_conv_grad(self):
+        # FIXME: remove after adding non-contiguous grad tests for all modules
+        module = nn.Conv2d(3, 5, kernel_size=3, padding=1).cuda()
+        input = Variable(torch.randn(2, 3, 10, 10).cuda(), requires_grad=True)
+        output = module(input)
+
+        grad = torch.randn(2, 2, 5, 10, 10).cuda()[:, 1]
+        assert not grad.is_contiguous()
+        output.backward(grad, retain_variables=True)
+        result = input.grad.data.clone()
+        input.grad.data.zero_()
+
+        output.backward(grad.contiguous())
+        self.assertEqual(result, input.grad.data)
+
     def test_pixel_shuffle(self):
         batch_size = random.randint(1, 3)
         upscale_factor = random.randint(2, 5)

torch/nn/_functions/conv.py

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ def forward(self, input, weight, bias=None):

     def backward(self, grad_output):
         k = grad_output.dim()
+        grad_output = grad_output.contiguous()
         input, weight, bias = self.saved_tensors
         if k == 3:
             grad_output, input, weight = _view4d(grad_output, input, weight)
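
The added call is cheap when the gradient is already contiguous: Tensor.contiguous() returns the same underlying storage untouched in that case and only copies when the layout is actually strided. A short standalone check (not part of this commit) illustrating that behavior:

    import torch

    x = torch.randn(4, 4)
    assert x.contiguous().data_ptr() == x.data_ptr()   # already dense: no copy

    y = x.t()                                          # transposed view, not dense
    assert not y.is_contiguous()
    z = y.contiguous()                                 # materializes a dense copy
    assert z.is_contiguous() and z.data_ptr() != y.data_ptr()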

0 commit comments
