diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py index 2fac3caf3762..163f125c28a5 100644 --- a/numpy/core/einsumfunc.py +++ b/numpy/core/einsumfunc.py @@ -4,6 +4,8 @@ """ from __future__ import division, absolute_import, print_function +import itertools + from numpy.compat import basestring from numpy.core.multiarray import c_einsum from numpy.core.numeric import asarray, asanyarray, result_type, tensordot, dot @@ -14,6 +16,44 @@ einsum_symbols_set = set(einsum_symbols) +def _flop_count(idx_contraction, inner, num_terms, size_dictionary): + """ + Computes the number of FLOPS in the contraction. + + Parameters + ---------- + idx_contraction : iterable + The indices involved in the contraction + inner : bool + Does this contraction require an inner product? + num_terms : int + The number of terms in a contraction + size_dictionary : dict + The size of each of the indices in idx_contraction + + Returns + ------- + flop_count : int + The total number of FLOPS required for the contraction. + + Examples + -------- + + >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5}) + 90 + + >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5}) + 270 + + """ + + overall_size = _compute_size_by_dict(idx_contraction, size_dictionary) + op_factor = max(1, num_terms - 1) + if inner: + op_factor += 1 + + return overall_size * op_factor + def _compute_size_by_dict(indices, idx_dict): """ Computes the product of the elements in indices based on the dictionary @@ -139,14 +179,9 @@ def _optimal_path(input_sets, output_set, idx_dict, memory_limit): iter_results = [] # Compute all unique pairs - comb_iter = [] - for x in range(len(input_sets) - iteration): - for y in range(x + 1, len(input_sets) - iteration): - comb_iter.append((x, y)) - for curr in full_results: cost, positions, remaining = curr - for con in comb_iter: + for con in itertools.combinations(range(len(input_sets) - iteration), 2): # Find the contraction cont = _find_contraction(con, remaining, output_set) @@ -157,15 +192,10 @@ def _optimal_path(input_sets, output_set, idx_dict, memory_limit): if new_size > memory_limit: continue - # Find cost - new_cost = _compute_size_by_dict(idx_contract, idx_dict) - if idx_removed: - new_cost *= 2 - # Build (total_cost, positions, indices_remaining) - new_cost += cost + total_cost = cost + _flop_count(idx_contract, idx_removed, len(con), idx_dict) new_pos = positions + [con] - iter_results.append((new_cost, new_pos, new_input_sets)) + iter_results.append((total_cost, new_pos, new_input_sets)) # Update combinatorial list, if we did not find anything return best # path + remaining contractions @@ -183,6 +213,102 @@ def _optimal_path(input_sets, output_set, idx_dict, memory_limit): path = min(full_results, key=lambda x: x[0])[1] return path +def _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, naive_cost): + """Compute the cost (removed size + flops) and resultant indices for + performing the contraction specified by ``positions``. + + Parameters + ---------- + positions : tuple of int + The locations of the proposed tensors to contract. + input_sets : list of sets + The indices found on each tensors. + output_set : set + The output indices of the expression. + idx_dict : dict + Mapping of each index to its size. + memory_limit : int + The total allowed size for an intermediary tensor. + path_cost : int + The contraction cost so far. + naive_cost : int + The cost of the unoptimized expression. + + Returns + ------- + cost : (int, int) + A tuple containing the size of any indices removed, and the flop cost. + positions : tuple of int + The locations of the proposed tensors to contract. + new_input_sets : list of sets + The resulting new list of indices if this proposed contraction is performed. + + """ + + # Find the contraction + contract = _find_contraction(positions, input_sets, output_set) + idx_result, new_input_sets, idx_removed, idx_contract = contract + + # Sieve the results based on memory_limit + new_size = _compute_size_by_dict(idx_result, idx_dict) + if new_size > memory_limit: + return None + + # Build sort tuple + old_sizes = (_compute_size_by_dict(input_sets[p], idx_dict) for p in positions) + removed_size = sum(old_sizes) - new_size + + # NB: removed_size used to be just the size of any removed indices i.e.: + # helpers.compute_size_by_dict(idx_removed, idx_dict) + cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict) + sort = (-removed_size, cost) + + # Sieve based on total cost as well + if (path_cost + cost) > naive_cost: + return None + + # Add contraction to possible choices + return [sort, positions, new_input_sets] + + +def _update_other_results(results, best): + """Update the positions and provisional input_sets of ``results`` based on + performing the contraction result ``best``. Remove any involving the tensors + contracted. + + Parameters + ---------- + results : list + List of contraction results produced by ``_parse_possible_contraction``. + best : list + The best contraction of ``results`` i.e. the one that will be performed. + + Returns + ------- + mod_results : list + The list of modifed results, updated with outcome of ``best`` contraction. + """ + + best_con = best[1] + bx, by = best_con + mod_results = [] + + for cost, (x, y), con_sets in results: + + # Ignore results involving tensors just contracted + if x in best_con or y in best_con: + continue + + # Update the input_sets + del con_sets[by - int(by > x) - int(by > y)] + del con_sets[bx - int(bx > x) - int(bx > y)] + con_sets.insert(-1, best[2][-1]) + + # Update the position indices + mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by) + mod_results.append((cost, mod_con, con_sets)) + + return mod_results def _greedy_path(input_sets, output_set, idx_dict, memory_limit): """ @@ -219,46 +345,68 @@ def _greedy_path(input_sets, output_set, idx_dict, memory_limit): [(0, 2), (0, 1)] """ + # Handle trivial cases that leaked through if len(input_sets) == 1: return [(0,)] + elif len(input_sets) == 2: + return [(0, 1)] + + # Build up a naive cost + contract = _find_contraction(range(len(input_sets)), input_sets, output_set) + idx_result, new_input_sets, idx_removed, idx_contract = contract + naive_cost = _flop_count(idx_contract, idx_removed, len(input_sets), idx_dict) + # Initially iterate over all pairs + comb_iter = itertools.combinations(range(len(input_sets)), 2) + known_contractions = [] + + path_cost = 0 path = [] - for iteration in range(len(input_sets) - 1): - iteration_results = [] - comb_iter = [] - # Compute all unique pairs - for x in range(len(input_sets)): - for y in range(x + 1, len(input_sets)): - comb_iter.append((x, y)) + for iteration in range(len(input_sets) - 1): + # Iterate over all pairs on first step, only previously found pairs on subsequent steps for positions in comb_iter: - # Find the contraction - contract = _find_contraction(positions, input_sets, output_set) - idx_result, new_input_sets, idx_removed, idx_contract = contract - - # Sieve the results based on memory_limit - if _compute_size_by_dict(idx_result, idx_dict) > memory_limit: + # Always initially ignore outer products + if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]): continue - # Build sort tuple - removed_size = _compute_size_by_dict(idx_removed, idx_dict) - cost = _compute_size_by_dict(idx_contract, idx_dict) - sort = (-removed_size, cost) + result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, + naive_cost) + if result is not None: + known_contractions.append(result) - # Add contraction to possible choices - iteration_results.append([sort, positions, new_input_sets]) + # If we do not have a inner contraction, rescan pairs including outer products + if len(known_contractions) == 0: - # If we did not find a new contraction contract remaining - if len(iteration_results) == 0: - path.append(tuple(range(len(input_sets)))) - break + # Then check the outer products + for positions in itertools.combinations(range(len(input_sets)), 2): + result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, + path_cost, naive_cost) + if result is not None: + known_contractions.append(result) + + # If we still did not find any remaining contractions, default back to einsum like behavior + if len(known_contractions) == 0: + path.append(tuple(range(len(input_sets)))) + break # Sort based on first index - best = min(iteration_results, key=lambda x: x[0]) - path.append(best[1]) + best = min(known_contractions, key=lambda x: x[0]) + + # Now propagate as many unused contractions as possible to next iteration + known_contractions = _update_other_results(known_contractions, best) + + # Next iteration only compute contractions with the new tensor + # All other contractions have been accounted for input_sets = best[2] + new_tensor_pos = len(input_sets) - 1 + comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos)) + + # Update path and total cost + path.append(best[1]) + path_cost += best[0][1] return path @@ -314,26 +462,27 @@ def _can_dot(inputs, result, idx_removed): if len(inputs) != 2: return False - # Build a few temporaries input_left, input_right = inputs + + for c in set(input_left + input_right): + # can't deal with repeated indices on same input or more than 2 total + nl, nr = input_left.count(c), input_right.count(c) + if (nl > 1) or (nr > 1) or (nl + nr > 2): + return False + + # can't do implicit summation or dimension collapse e.g. + # "ab,bc->c" (implicitly sum over 'a') + # "ab,ca->ca" (take diagonal of 'a') + if nl + nr - 1 == int(c in result): + return False + + # Build a few temporaries set_left = set(input_left) set_right = set(input_right) keep_left = set_left - idx_removed keep_right = set_right - idx_removed rs = len(idx_removed) - # Indices must overlap between the two operands - if not len(set_left & set_right): - return False - - # We cannot have duplicate indices ("ijj, jk -> ik") - if (len(set_left) != len(input_left)) or (len(set_right) != len(input_right)): - return False - - # Cannot handle partial inner ("ij, ji -> i") - if len(keep_left & keep_right): - return False - # At this point we are a DOT, GEMV, or GEMM operation # Handle inner products @@ -698,6 +847,7 @@ def einsum_path(*operands, **kwargs): # Get length of each unique dimension and ensure all dimensions are correct dimension_dict = {} + broadcast_indices = [[] for x in range(len(input_list))] for tnum, term in enumerate(input_list): sh = operands[tnum].shape if len(sh) != len(term): @@ -706,6 +856,11 @@ def einsum_path(*operands, **kwargs): % (input_subscripts[tnum], tnum)) for cnum, char in enumerate(term): dim = sh[cnum] + + # Build out broadcast indices + if dim == 1: + broadcast_indices[tnum].append(char) + if char in dimension_dict.keys(): # For broadcasting cases we always want the largest dim size if dimension_dict[char] == 1: @@ -717,6 +872,9 @@ def einsum_path(*operands, **kwargs): else: dimension_dict[char] = dim + # Convert broadcast inds to sets + broadcast_indices = [set(x) for x in broadcast_indices] + # Compute size of each input array plus the output array size_list = [] for term in input_list + [output_subscript]: @@ -730,20 +888,14 @@ def einsum_path(*operands, **kwargs): # Compute naive cost # This isn't quite right, need to look into exactly how einsum does this - naive_cost = _compute_size_by_dict(indices, dimension_dict) - indices_in_input = input_subscripts.replace(',', '') - mult = max(len(input_list) - 1, 1) - if (len(indices_in_input) - len(set(indices_in_input))): - mult *= 2 - naive_cost *= mult + inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0 + naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict) # Compute the path if (path_type is False) or (len(input_list) in [1, 2]) or (indices == output_set): # Nothing to be optimized, leave it to einsum path = [tuple(range(len(input_list)))] elif path_type == "greedy": - # Maximum memory should be at most out_size for this algorithm - memory_arg = min(memory_arg, max_size) path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg) elif path_type == "optimal": path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg) @@ -762,18 +914,24 @@ def einsum_path(*operands, **kwargs): contract = _find_contraction(contract_inds, input_sets, output_set) out_inds, input_sets, idx_removed, idx_contract = contract - cost = _compute_size_by_dict(idx_contract, dimension_dict) - if idx_removed: - cost *= 2 + cost = _flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict) cost_list.append(cost) scale_list.append(len(idx_contract)) size_list.append(_compute_size_by_dict(out_inds, dimension_dict)) + bcast = set() tmp_inputs = [] for x in contract_inds: tmp_inputs.append(input_list.pop(x)) + bcast |= broadcast_indices.pop(x) - do_blas = _can_dot(tmp_inputs, out_inds, idx_removed) + new_bcast_inds = bcast - idx_removed + + # If we're broadcasting, nix blas + if not len(idx_removed & bcast): + do_blas = _can_dot(tmp_inputs, out_inds, idx_removed) + else: + do_blas = False # Last contraction if (cnum - len(path)) == -1: @@ -783,6 +941,7 @@ def einsum_path(*operands, **kwargs): idx_result = "".join([x[1] for x in sorted(sort_result)]) input_list.append(idx_result) + broadcast_indices.append(new_bcast_inds) einsum_str = ",".join(tmp_inputs) + "->" + idx_result contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas) @@ -1200,25 +1359,14 @@ def einsum(*operands, **kwargs): tmp_operands.append(operands.pop(x)) # Do we need to deal with the output? - if specified_out and ((num + 1) == len(contraction_list)): - handle_out = True + handle_out = specified_out and ((num + 1) == len(contraction_list)) - # Handle broadcasting vs BLAS cases + # Call tensordot if still possible if blas: # Checks have already been handled input_str, results_index = einsum_str.split('->') input_left, input_right = input_str.split(',') - if 1 in tmp_operands[0].shape or 1 in tmp_operands[1].shape: - left_dims = {dim: size for dim, size in - zip(input_left, tmp_operands[0].shape)} - right_dims = {dim: size for dim, size in - zip(input_right, tmp_operands[1].shape)} - # If dims do not match we are broadcasting, BLAS off - if any(left_dims[ind] != right_dims[ind] for ind in idx_rm): - blas = False - # Call tensordot if still possible - if blas: tensor_result = input_left + input_right for s in idx_rm: tensor_result = tensor_result.replace(s, "") diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index a72079218c74..8ce374a759e7 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -16,7 +16,7 @@ global_size_dict[char] = size -class TestEinSum(object): +class TestEinsum(object): def test_einsum_errors(self): for do_opt in [True, False]: # Need enough arguments @@ -614,7 +614,7 @@ def test_subscript_range(self): np.einsum(a, [0, 51], b, [51, 2], [0, 2], optimize=False) assert_raises(ValueError, lambda: np.einsum(a, [0, 52], b, [52, 2], [0, 2], optimize=False)) assert_raises(ValueError, lambda: np.einsum(a, [-1, 5], b, [5, 2], [-1, 2], optimize=False)) - + def test_einsum_broadcast(self): # Issue #2455 change in handling ellipsis # remove the 'middle broadcast' error @@ -735,19 +735,22 @@ def test_out_is_res(self): res = np.einsum('...ij,...jk->...ik', a, a, out=a) assert res is a - def optimize_compare(self, string): + def optimize_compare(self, subscripts, operands=None): # Tests all paths of the optimization function against # conventional einsum - operands = [string] - terms = string.split('->')[0].split(',') - for term in terms: - dims = [global_size_dict[x] for x in term] - operands.append(np.random.rand(*dims)) - - noopt = np.einsum(*operands, optimize=False) - opt = np.einsum(*operands, optimize='greedy') + if operands is None: + args = [subscripts] + terms = subscripts.split('->')[0].split(',') + for term in terms: + dims = [global_size_dict[x] for x in term] + args.append(np.random.rand(*dims)) + else: + args = [subscripts] + operands + + noopt = np.einsum(*args, optimize=False) + opt = np.einsum(*args, optimize='greedy') assert_almost_equal(opt, noopt) - opt = np.einsum(*operands, optimize='optimal') + opt = np.einsum(*args, optimize='optimal') assert_almost_equal(opt, noopt) def test_hadamard_like_products(self): @@ -833,8 +836,28 @@ def test_combined_views_mapping(self): b = np.einsum('bbcdc->d', a) assert_equal(b, [12]) + def test_broadcasting_dot_cases(self): + # Ensures broadcasting cases are not mistaken for GEMM -class TestEinSumPath(object): + a = np.random.rand(1, 5, 4) + b = np.random.rand(4, 6) + c = np.random.rand(5, 6) + d = np.random.rand(10) + + self.optimize_compare('ijk,kl,jl', operands=[a, b, c]) + self.optimize_compare('ijk,kl,jl,i->i', operands=[a, b, c, d]) + + e = np.random.rand(1, 1, 5, 4) + f = np.random.rand(7, 7) + self.optimize_compare('abjk,kl,jl', operands=[e, b, c]) + self.optimize_compare('abjk,kl,jl,ab->ab', operands=[e, b, c, f]) + + # Edge case found in gh-11308 + g = np.arange(64).reshape(2, 4, 8) + self.optimize_compare('obk,ijk->ioj', operands=[g, g]) + + +class TestEinsumPath(object): def build_operands(self, string, size_dict=global_size_dict): # Builds views based off initial operands @@ -880,7 +903,7 @@ def test_long_paths(self): long_test1 = self.build_operands('acdf,jbje,gihb,hfac,gfac,gifabc,hfac') path, path_str = np.einsum_path(*long_test1, optimize='greedy') self.assert_path_equal(path, ['einsum_path', - (1, 4), (2, 4), (1, 4), (1, 3), (1, 2), (0, 1)]) + (3, 6), (3, 4), (2, 4), (2, 3), (0, 2), (0, 1)]) path, path_str = np.einsum_path(*long_test1, optimize='optimal') self.assert_path_equal(path, ['einsum_path', @@ -889,10 +912,12 @@ def test_long_paths(self): # Long test 2 long_test2 = self.build_operands('chd,bde,agbc,hiad,bdi,cgh,agdb') path, path_str = np.einsum_path(*long_test2, optimize='greedy') + print(path) self.assert_path_equal(path, ['einsum_path', (3, 4), (0, 3), (3, 4), (1, 3), (1, 2), (0, 1)]) path, path_str = np.einsum_path(*long_test2, optimize='optimal') + print(path) self.assert_path_equal(path, ['einsum_path', (0, 5), (1, 4), (3, 4), (1, 3), (1, 2), (0, 1)]) @@ -926,7 +951,7 @@ def test_edge_paths(self): # Edge test4 edge_test4 = self.build_operands('dcc,fce,ea,dbf->ab') path, path_str = np.einsum_path(*edge_test4, optimize='greedy') - self.assert_path_equal(path, ['einsum_path', (0, 3), (0, 2), (0, 1)]) + self.assert_path_equal(path, ['einsum_path', (1, 2), (0, 1), (0, 1)]) path, path_str = np.einsum_path(*edge_test4, optimize='optimal') self.assert_path_equal(path, ['einsum_path', (1, 2), (0, 2), (0, 1)]) @@ -949,7 +974,7 @@ def test_path_type_input(self): self.assert_path_equal(path, ['einsum_path', (0, 1, 2, 3)]) path, path_str = np.einsum_path(*path_test, optimize=True) - self.assert_path_equal(path, ['einsum_path', (0, 3), (0, 2), (0, 1)]) + self.assert_path_equal(path, ['einsum_path', (1, 2), (0, 1), (0, 1)]) exp_path = ['einsum_path', (0, 2), (0, 2), (0, 1)] path, path_str = np.einsum_path(*path_test, optimize=exp_path)