diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py index 9ad0c7e1788d..da78748a3501 100644 --- a/numpy/core/einsumfunc.py +++ b/numpy/core/einsumfunc.py @@ -706,10 +706,13 @@ def einsum_path(*operands, **kwargs): for cnum, char in enumerate(term): dim = sh[cnum] if char in dimension_dict.keys(): - if dimension_dict[char] != dim: - raise ValueError("Size of label '%s' for operand %d does " - "not match previous terms." - % (char, tnum)) + # For broadcasting cases we always want the largest dim size + if dimension_dict[char] == 1: + dimension_dict[char] = dim + elif dim not in (1, dimension_dict[char]): + raise ValueError("Size of label '%s' for operand %d (%d) " + "does not match previous terms (%d)." + % (char, tnum, dimension_dict[char], dim)) else: dimension_dict[char] = dim @@ -1101,13 +1104,22 @@ def einsum(*operands, **kwargs): if specified_out and ((num + 1) == len(contraction_list)): handle_out = True - # Call tensordot + # Handle broadcasting vs BLAS cases if blas: - # Checks have already been handled input_str, results_index = einsum_str.split('->') input_left, input_right = input_str.split(',') - + if 1 in tmp_operands[0] or 1 in tmp_operands[1]: + left_dims = {dim: size for dim, size in + zip(input_left, tmp_operands[0].shape)} + right_dims = {dim: size for dim, size in + zip(input_right, tmp_operands[1].shape)} + # If dims do not match we are broadcasting, BLAS off + if any(left_dims[ind] != right_dims[ind] for ind in idx_rm): + blas = False + + # Call tensordot if still possible + if blas: tensor_result = input_left + input_right for s in idx_rm: tensor_result = tensor_result.replace(s, "") diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index d07256ec291a..9bd85fdb9947 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -481,6 +481,25 @@ def check_einsum_sums(self, dtype, do_opt=False): r = np.arange(4).reshape(2, 2) + 7 assert_equal(np.einsum('z,mz,zm->', p, q, r), 253) + # singleton dimensions broadcast (gh-10343) + p = np.ones((10,2)) + q = np.ones((1,2)) + assert_array_equal(np.einsum('ij,ij->j', p, q, optimize=True), + np.einsum('ij,ij->j', p, q, optimize=False)) + assert_array_equal(np.einsum('ij,ij->j', p, q, optimize=True), + [10.] * 2) + + p = np.ones((1, 5)) + q = np.ones((5, 5)) + for optimize in (True, False): + assert_array_equal(np.einsum("...ij,...jk->...ik", p, p, + optimize=optimize), + np.einsum("...ij,...jk->...ik", p, q, + optimize=optimize)) + assert_array_equal(np.einsum("...ij,...jk->...ik", p, q, + optimize=optimize), + np.full((1, 5), 5)) + def test_einsum_sums_int8(self): self.check_einsum_sums('i1')