From bd505de0c32841b3377217917c1e56ecf74e347e Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 9 Jan 2018 14:40:11 -0500 Subject: [PATCH 1/4] FIX: Fix einsum optimize logic for singleton dimensions --- numpy/core/einsumfunc.py | 10 ++++++---- numpy/core/tests/test_einsum.py | 8 ++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py index 9ad0c7e1788d..fdbffd11b6d4 100644 --- a/numpy/core/einsumfunc.py +++ b/numpy/core/einsumfunc.py @@ -706,10 +706,12 @@ def einsum_path(*operands, **kwargs): for cnum, char in enumerate(term): dim = sh[cnum] if char in dimension_dict.keys(): - if dimension_dict[char] != dim: - raise ValueError("Size of label '%s' for operand %d does " - "not match previous terms." - % (char, tnum)) + if dimension_dict[char] == 1: + dimension_dict[char] = dim + elif dim not in (1, dimension_dict[char]): + raise ValueError("Size of label '%s' for operand %d (%d) " + "does not match previous terms (%d)." + % (char, tnum, dimension_dict[char], dim)) else: dimension_dict[char] = dim diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index d07256ec291a..663841cd4fc6 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -481,6 +481,14 @@ def check_einsum_sums(self, dtype, do_opt=False): r = np.arange(4).reshape(2, 2) + 7 assert_equal(np.einsum('z,mz,zm->', p, q, r), 253) + # singleton dimensions broadcast (gh-10343) + p = np.ones((10,2)) + q = np.ones((1,2)) + assert_array_equal(np.einsum('ti,ti->i', p, q, optimize=True), + np.einsum('ti,ti->i', p, q, optimize=False)) + assert_array_equal(np.einsum('ti,ti->i', p, q, optimize=True), + [10.] * 2) + def test_einsum_sums_int8(self): self.check_einsum_sums('i1') From 01e313e4855172c0bc7d1f3f3aab5cb5f1121f2b Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Sun, 14 Jan 2018 01:52:26 -0500 Subject: [PATCH 2/4] ENH: Add broadcasting test --- numpy/core/tests/test_einsum.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index 663841cd4fc6..9bd85fdb9947 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -484,11 +484,22 @@ def check_einsum_sums(self, dtype, do_opt=False): # singleton dimensions broadcast (gh-10343) p = np.ones((10,2)) q = np.ones((1,2)) - assert_array_equal(np.einsum('ti,ti->i', p, q, optimize=True), - np.einsum('ti,ti->i', p, q, optimize=False)) - assert_array_equal(np.einsum('ti,ti->i', p, q, optimize=True), + assert_array_equal(np.einsum('ij,ij->j', p, q, optimize=True), + np.einsum('ij,ij->j', p, q, optimize=False)) + assert_array_equal(np.einsum('ij,ij->j', p, q, optimize=True), [10.] * 2) + p = np.ones((1, 5)) + q = np.ones((5, 5)) + for optimize in (True, False): + assert_array_equal(np.einsum("...ij,...jk->...ik", p, p, + optimize=optimize), + np.einsum("...ij,...jk->...ik", p, q, + optimize=optimize)) + assert_array_equal(np.einsum("...ij,...jk->...ik", p, q, + optimize=optimize), + np.full((1, 5), 5)) + def test_einsum_sums_int8(self): self.check_einsum_sums('i1') From 2afc7d5c8e474f06d02e2a137cdda81237bfb38d Mon Sep 17 00:00:00 2001 From: Daniel Smith Date: Wed, 17 Jan 2018 16:20:01 -0500 Subject: [PATCH 3/4] Patches up broadcasting einsum issues for BLAS cases --- numpy/core/einsumfunc.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py index fdbffd11b6d4..50ec8d930c68 100644 --- a/numpy/core/einsumfunc.py +++ b/numpy/core/einsumfunc.py @@ -706,6 +706,8 @@ def einsum_path(*operands, **kwargs): for cnum, char in enumerate(term): dim = sh[cnum] if char in dimension_dict.keys(): + + # For broadcasting cases we always want the largest dim size if dimension_dict[char] == 1: dimension_dict[char] = dim elif dim not in (1, dimension_dict[char]): @@ -1103,6 +1105,22 @@ def einsum(*operands, **kwargs): if specified_out and ((num + 1) == len(contraction_list)): handle_out = True + # Handle broadcasting vs BLAS cases + if blas and ((1 in tmp_operands[0]) or (1 in tmp_operands[1])): + + # Checks have already been handled + input_str, results_index = einsum_str.split('->') + input_left, input_right = input_str.split(',') + + left_dims = {dim : size for dim, size in zip(input_left, tmp_operands[0].shape)} + right_dims = {dim : size for dim, size in zip(input_right, tmp_operands[1].shape)} + + # If dims do not match we are broadcasting, BLAS off + for ind in idx_rm: + if left_dims[ind] != right_dims[ind]: + blas = False + break + # Call tensordot if blas: From ab3e91c1f9fcacb3864be9f967d17298843c8e26 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 18 Jan 2018 11:14:18 -0500 Subject: [PATCH 4/4] FIX: Deduplicate code --- numpy/core/einsumfunc.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py index 50ec8d930c68..da78748a3501 100644 --- a/numpy/core/einsumfunc.py +++ b/numpy/core/einsumfunc.py @@ -706,7 +706,6 @@ def einsum_path(*operands, **kwargs): for cnum, char in enumerate(term): dim = sh[cnum] if char in dimension_dict.keys(): - # For broadcasting cases we always want the largest dim size if dimension_dict[char] == 1: dimension_dict[char] = dim @@ -1106,28 +1105,21 @@ def einsum(*operands, **kwargs): handle_out = True # Handle broadcasting vs BLAS cases - if blas and ((1 in tmp_operands[0]) or (1 in tmp_operands[1])): - + if blas: # Checks have already been handled input_str, results_index = einsum_str.split('->') input_left, input_right = input_str.split(',') - - left_dims = {dim : size for dim, size in zip(input_left, tmp_operands[0].shape)} - right_dims = {dim : size for dim, size in zip(input_right, tmp_operands[1].shape)} - - # If dims do not match we are broadcasting, BLAS off - for ind in idx_rm: - if left_dims[ind] != right_dims[ind]: + if 1 in tmp_operands[0] or 1 in tmp_operands[1]: + left_dims = {dim: size for dim, size in + zip(input_left, tmp_operands[0].shape)} + right_dims = {dim: size for dim, size in + zip(input_right, tmp_operands[1].shape)} + # If dims do not match we are broadcasting, BLAS off + if any(left_dims[ind] != right_dims[ind] for ind in idx_rm): blas = False - break - # Call tensordot + # Call tensordot if still possible if blas: - - # Checks have already been handled - input_str, results_index = einsum_str.split('->') - input_left, input_right = input_str.split(',') - tensor_result = input_left + input_right for s in idx_rm: tensor_result = tensor_result.replace(s, "")