Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0cb2409

Browse files
gchananfacebook-github-bot
authored andcommitted
Handle non-contiguous inputs with mkldnn convolution. (#16300)
Summary: Fixes #16018. Backwards appears to be fine because the derivative is written in terms of mkldnn_convolution itself. Pull Request resolved: #16300 Differential Revision: D13797776 Pulled By: gchanan fbshipit-source-id: 68a990b8a3c186412a99d176931314806c9ed7bf
1 parent 45c3cc9 commit 0cb2409

2 files changed

Lines changed: 35 additions & 2 deletions

File tree

aten/src/ATen/native/Convolution.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,8 +399,8 @@ at::Tensor _convolution(
399399
AT_CHECK(!bias.defined() || (input.type() == bias.type()),
400400
"Input type (", input.type().toString(), ") and bias type (", bias.type().toString(),
401401
") should be the same");
402-
403-
output = at::mkldnn_convolution(input, weight, bias, params.padding, params.stride, params.dilation, params.groups);
402+
output = at::mkldnn_convolution(input, weight.contiguous(), bias.defined() ? bias.contiguous() : bias,
403+
params.padding, params.stride, params.dilation, params.groups);
404404
#endif
405405
} else {
406406
if (params.groups == 1) {

test/test_nn.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6591,6 +6591,39 @@ def test_conv_noncontig_weights(self):
65916591
def test_conv_noncontig_weights_cuda(self):
65926592
self._test_conv_noncontig_weights(self, torch.device('cuda'))
65936593

6594+
@staticmethod
6595+
def _test_conv_noncontig_weights_and_bias(self, device):
6596+
# need floats to exercise https://github.com/pytorch/pytorch/issues/16018
6597+
for bias in [True, False]:
6598+
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
6599+
bias=bias).to(device, torch.float)
6600+
6601+
input_nc = torch.randn((1, 3, 224, 224, 2), device=device, dtype=torch.float)[:, :, :, :, 1]
6602+
input_c = input_nc.contiguous()
6603+
6604+
weight_nc = torch.randn((64, 3, 7, 7, 2), device=device, dtype=torch.float)[:, :, :, :, 1]
6605+
conv1.weight = nn.Parameter(weight_nc)
6606+
weight_c = conv1.weight.contiguous()
6607+
6608+
if bias:
6609+
bias_nc = torch.randn((64, 2), device=device, dtype=torch.float)[:, 1]
6610+
conv1.bias = nn.Parameter(bias_nc)
6611+
bias_c = conv1.bias.contiguous()
6612+
6613+
out1 = conv1(input_nc)
6614+
conv1.weight = nn.Parameter(weight_c)
6615+
if bias:
6616+
conv1.bias = nn.Parameter(bias_c)
6617+
out2 = conv1(input_c)
6618+
self.assertEqual(out1, out2)
6619+
6620+
def test_conv_noncontig_weights_and_bias(self):
6621+
self._test_conv_noncontig_weights_and_bias(self, torch.device('cpu'))
6622+
6623+
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
6624+
def test_conv_noncontig_weights_and_bias_cuda(self):
6625+
self._test_conv_noncontig_weights_and_bias(self, torch.device('cuda'))
6626+
65946627
def run_conv_double_back_test(self, kern, stride, padding, chan_in, chan_out, batch_size,
65956628
inp_size, dilation, no_weight, groups=1, use_cuda=False,
65966629
use_bias=True, dtype=torch.double):

0 commit comments

Comments
 (0)