diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index ed40a080441bc..3b1fdb69e8ef1 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -10,6 +10,10 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" #include #include @@ -28,67 +32,329 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType, return std::nullopt; } -std::optional> -mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, - ArrayRef targetShape) { - if (sourceShape.size() <= targetShape.size()) - return std::nullopt; - unsigned sourceDim = 0; - SmallVector reassociationMap; - reassociationMap.reserve(targetShape.size()); +namespace { +/// A simple struct to represent ReassociationIndices as an inclusive interval. +/// It's designed to be feasibly minimal, so the call sites should manage the +/// validity of the range manually. +struct ReassociationIndexRange { + /// FIXME: Signed type is used for consistency with ReassociationIndices. + /// We should consider refactoring all reassociation utilities to use unsigned + /// types. + int64_t leftIdx = 0, rightIdx = 0; + + /// Util for manual checks of the range's validity + LogicalResult verify() const { + return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure(); + } + + /// Checks range's containment within another range. Treats the edges + /// non-exclusively. + bool isInRange(const ReassociationIndexRange &outerRange) const { + return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx; + } - ReassociationIndices currIndices; + unsigned size() const { + assert(succeeded(verify())); + return rightIdx - leftIdx + 1; + } + bool containsSingleIndex() const { return size() == 1; } + + /// Collects indices that do not overlap between this and another range. + ReassociationIndices + getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const { + if (rightIdx < rhs.leftIdx) { + // The intervals do not overlap - concatenate the indices from both. + auto jointFullIndices = getFullIndices(); + jointFullIndices.append(rhs.getFullIndices()); + return jointFullIndices; + } + ReassociationIndices result; + // Handle the chunk left of the overlapping range. + int64_t leftStart = std::min(leftIdx, rhs.leftIdx); + int64_t leftEnd = std::max(leftIdx, rhs.leftIdx); + llvm::append_range(result, llvm::seq(leftStart, leftEnd)); + // Handle the chunk right of the overlapping range. Symmetrically, we should + // skip the edge of the overlap AND include the rightmost index. + int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1; + int64_t rightEnd = std::max(rightIdx, rhs.rightIdx); + if (rightStart < rightEnd) + llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd)); + return result; + } + + /// Converts the range into ReassociationIndices. + ReassociationIndices getFullIndices() const { + ReassociationIndices result; + for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) { + result.push_back(idx); + } + return result; + } +}; +} // namespace + +/// Starting from `sourceStartIdx`, searches `sourceShape` for the first +/// sequence that can be collapsed into a dynamic dimension (at least one must +/// be present in the source). +/// By default, lazily returns once the first dynamic dimension has been found. +/// Setting `matchGreedily` as `true` will also mark all subsequent +/// source dimensions for collapsing into the target. +static FailureOr +findReassociationRangeForDynamicDim(ArrayRef sourceShape, + int64_t sourceStartIdx, + bool matchGreedily = false) { + const unsigned numSourceDims = sourceShape.size(); + ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1}; + std::optional resultRange = std::nullopt; + + ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx}; + for (; iterationRange.isInRange(sourceShapeAsRange); + iterationRange.rightIdx++) { + int64_t sourceSize = sourceShape[iterationRange.rightIdx]; + if (sourceSize == ShapedType::kDynamic) { + resultRange = iterationRange; + break; + } + } + if (!resultRange) + return failure(); + if (matchGreedily) + resultRange->rightIdx = sourceShapeAsRange.rightIdx; + return *resultRange; +} + +/// Starting from `sourceStartIdx`, searches `sourceShape` for the first +/// sequence of static dimensions such that their product matches `targetSize`. +/// By default, lazily returns once the product matches the target size. Setting +/// `matchGreedily` as `true` will append all neighboring unit dimensions +/// (dimensions of 1) to the match. +static FailureOr +findReassociationRangeForSize(ArrayRef sourceShape, + int64_t sourceStartIdx, int64_t targetSize, + bool matchGreedily = false) { + const unsigned numSourceDims = sourceShape.size(); + ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1}; + std::optional resultRange = std::nullopt; + + ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx}; int64_t prodOfCollapsedDims = 1; - while (sourceDim < sourceShape.size()) { - unsigned targetDim = reassociationMap.size(); - // If we have mapped all the target dimensions stop and handle the remaining - // tail of size-1 dimensions explicitly. - if (targetDim == targetShape.size()) + while (iterationRange.isInRange(sourceShapeAsRange)) { + int64_t sourceSize = sourceShape[iterationRange.rightIdx]; + if (sourceSize == ShapedType::kDynamic) { + // Reassociation for a static dim cannot include a dynamic dim. Reset + // induction variables to essentially restart the loop from the next + // source dimension. + prodOfCollapsedDims = 1; + iterationRange = {iterationRange.rightIdx + 1, + iterationRange.rightIdx + 1}; + continue; + } + prodOfCollapsedDims *= sourceSize; + // If the target size has been exceeded without matching, we need to shift + // the range start right. From the start of the range, roll back the + // multiplication until the target size exceeds the product again. + while (prodOfCollapsedDims > targetSize && + !iterationRange.containsSingleIndex()) { + int64_t frontSourceSize = sourceShape[iterationRange.leftIdx]; + prodOfCollapsedDims /= frontSourceSize; + // Shrink the range rightwards + iterationRange.leftIdx++; + } + // We could've reached the target size with the current dimension, + // also as a result of the above shift to right. + if (prodOfCollapsedDims == targetSize) { + resultRange = iterationRange; break; + } + // Increment the iteration range + iterationRange.rightIdx++; + } + if (!resultRange) + return failure(); + if (matchGreedily) { + // We now want to collect all unit dimensions directly after the target + // product match. Advance the iterator to avoid OOB when the product match + // happens at the last element. + iterationRange.rightIdx++; + while (iterationRange.isInRange(sourceShapeAsRange) && + sourceShape[iterationRange.rightIdx] == 1) { + resultRange = iterationRange; + iterationRange.rightIdx++; + } + } + return *resultRange; +} + +/// Attempts to find a valid collapsing reassociation of `sourceShape` into +/// `targetShape` through a simple traversal. If successful, an array of source +/// index ranges is returned, correspondingly to each dimension in the target +/// shape. The resulting indices shall fully cover the `sourceShape` without +/// overlaps. +/// +/// The algorithm is essentially a lazy one, searching for non-greedy matches - +/// it will only yield a greedy match for the last target dimension. +/// FIXME: The algorithm can only backtrack when it needs to append an offset +/// for a static target dimension to the preceding dynamic one (this retains the +/// linear complexity). As feasible, consider adding further backtracking +/// routines to enable more reassociations, e.g.: +/// - ?x2x?x2 into ?x2 +static FailureOr> +findReassociationRangesForCollapse(ArrayRef sourceShape, + ArrayRef targetShape) { + unsigned numSourceDims = sourceShape.size(), + numTargetDims = targetShape.size(); + assert(numSourceDims > numTargetDims); + ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1}; + + SmallVector reassocRanges; + reassocRanges.reserve(numTargetDims); + // We'll iterate in strides of 2 to enable pseudo-backtracking for simple + // cases, e.g.: + // - ?x2x3x5 into ?x15 + std::optional prevTargetSize = std::nullopt; + for (unsigned targetDimIdx = 0, sourceDimIdx = 0; + targetDimIdx < numTargetDims; ++targetDimIdx) { + int64_t targetSize = targetShape[targetDimIdx]; + // Simply check if there are any subsequent target dimensions left - if not, + // the match must be made greedily. + bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1; + FailureOr sourceRange; + if (targetSize == ShapedType::kDynamic) { + sourceRange = findReassociationRangeForDynamicDim( + sourceShape, sourceDimIdx, shouldMatchGreedily); + } else { + sourceRange = findReassociationRangeForSize( + sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily); + } - int64_t currTargetShape = targetShape[targetDim]; - while (sourceDim < (sourceShape.size() - 1) && - sourceShape[sourceDim] != ShapedType::kDynamic && - prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) { - prodOfCollapsedDims *= sourceShape[sourceDim]; - currIndices.push_back(sourceDim++); + // Run sanity checks on the returned index range. + if (failed(sourceRange) || failed(sourceRange->verify()) || + !sourceRange->isInRange(sourceShapeAsRange)) + return failure(); + if (sourceRange->leftIdx > sourceDimIdx) { + // If some source dimensions had to be skipped in order to find a match, + // they must be collapsed into the directly preceding dynamic dimension. + if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic) + return failure(); + reassocRanges.back().rightIdx = sourceRange->leftIdx - 1; } - // If the current expanded dimension is dynamic, then the collapsed - // dimensions should also be dynamic and product of all previous unprocessed - // dimensions of the expanded shape should be 1. - if (sourceShape[sourceDim] == ShapedType::kDynamic && - (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1)) - return std::nullopt; - - // If the collapsed dim is dynamic, the current expanded dim should also - // be dynamic. - if (currTargetShape == ShapedType::kDynamic && - sourceShape[sourceDim] != ShapedType::kDynamic) - return std::nullopt; - - // For static shapes, if the product of dimensions of the expanded shape - // should match the collapsed dimension shape. - if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape) - return std::nullopt; - - currIndices.push_back(sourceDim++); - reassociationMap.emplace_back(ReassociationIndices{}); - std::swap(reassociationMap.back(), currIndices); - prodOfCollapsedDims = 1; + // Store the gathered information as required for the next iteration. + prevTargetSize = targetSize; + sourceDimIdx = sourceRange->rightIdx + 1; + reassocRanges.push_back(*sourceRange); + } + // Fail if the source shape wasn't a full match for the target shape. We only + // need to check the last recorded index - any other gaps should have been + // mended by the main loop. + if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx) + return failure(); + return reassocRanges; +} + +/// A variant of `findReassociationRangesForCollapse(...)` that can also scan +/// the shapes right-to-left. +static FailureOr> +findReassociationRangesForCollapse(ArrayRef sourceShape, + ArrayRef targetShape, + bool iterateRightToLeft) { + if (!iterateRightToLeft) + return findReassociationRangesForCollapse(sourceShape, targetShape); + // NB: To iterate right-to-left, we currently reverse the shapes and then + // reverse the result back. The reversed shapes must not be temporary, as + // we're passing through an ArrayRef. + // FIXME: It would be preferable to avoid the expensive copies. At the moment, + // this approach is chosen for readability of the main implementation. + std::vector sourceToReverse = sourceShape.vec(), + targetToReverse = targetShape.vec(); + std::reverse(sourceToReverse.begin(), sourceToReverse.end()); + std::reverse(targetToReverse.begin(), targetToReverse.end()); + auto invertedRanges = + findReassociationRangesForCollapse(sourceToReverse, targetToReverse); + if (failed(invertedRanges)) + return failure(); + SmallVector &rangesToInvert = *invertedRanges; + unsigned numSourceDims = sourceShape.size(); + // We have received the ranges for inverted shapes. Now we have to invert + // the ranges back to correspond with the original source shape. + for (auto &range : rangesToInvert) { + int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx; + range.leftIdx = numSourceDims - 1 - invRightIdx; + range.rightIdx = numSourceDims - 1 - invLeftIdx; } - // All the dimensions in the target must have been processed. - if (reassociationMap.size() != targetShape.size()) + // Also invert the ordering of the ranges to correspond with the original + // target shape. + std::reverse(rangesToInvert.begin(), rangesToInvert.end()); + return rangesToInvert; +} + +std::optional> +mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, + ArrayRef targetShape) { + unsigned numSourceDims = sourceShape.size(), + numTargetDims = targetShape.size(); + // We're supposed to search for a collapsing reassociation. If the sizes + // match, there's no actual collapsing taking place - it's either a no-op or a + // `tensor.reshape`-style reassociation (that would be beyond the scope of + // this utility). + if (numSourceDims <= numTargetDims) return std::nullopt; - // Process any remaining entries in the source shape. They all need to be - // 1 or dynamic. - for (; sourceDim < sourceShape.size(); sourceDim++) { - if (sourceShape[sourceDim] != ShapedType::kDynamic && - sourceShape[sourceDim] != 1) - return std::nullopt; - // The map is empty when the target type is a scalar. - if (!reassociationMap.empty()) - reassociationMap.back().push_back(sourceDim); + // Early handling for scalar target types. + if (numTargetDims == 0) { + ReassociationIndices allSourceIndices; + allSourceIndices.reserve(numSourceDims); + for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims; + ++sourceDimIdx) { + int64_t sourceSize = sourceShape[sourceDimIdx]; + // All source dimensions must be unit or dynamic. + if (sourceSize != 1 && sourceSize != ShapedType::kDynamic) + return std::nullopt; + allSourceIndices.push_back(sourceDimIdx); + } + return SmallVector{allSourceIndices}; + } + + // Collect source ranges by iterating over the target shape left-to-right. + FailureOr> maybeForwardRanges = + findReassociationRangesForCollapse(sourceShape, targetShape); + if (failed(maybeForwardRanges)) + return std::nullopt; + auto &ranges = *maybeForwardRanges; + // Now do the same in reverse. We need to get another valid reassociation + // through some other strategy, and then compare the results in order to + // disambiguate mixed subshapes, such as: + // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x? + // This leads us to lose some of the reassociation opportunities that can only + // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without + // backtracking, the algorithm will fail right-to-left. However, this is the + // best way to preserve correctness. + FailureOr> maybeReverseRanges = + findReassociationRangesForCollapse(sourceShape, targetShape, + /*iterateRightToLeft=*/true); + if (failed(maybeReverseRanges)) + return std::nullopt; + auto &reverseRanges = *maybeReverseRanges; + + if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims) + return std::nullopt; + // Now we can check for ambiguity of each target dimension's reassociation. If + // successful, we put the full indices into our result map for the target + // shape. + SmallVector reassociationMap(numTargetDims); + for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims; + ++targetDimIdx) { + ReassociationIndexRange &range = ranges[targetDimIdx]; + ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx]; + // Get non-overlapping indices between the ranges + ReassociationIndices nonMatchingIndices = + range.getNonOverlappingIndicesWith(reverseRange); + // Unit dimensions can be collapsed wherever - this is the only ambiguity + // that we allow. + for (int64_t sourceDimIdx : nonMatchingIndices) { + if (sourceShape[sourceDimIdx] != 1) + return std::nullopt; + } + reassociationMap[targetDimIdx] = range.getFullIndices(); } return reassociationMap; } @@ -315,11 +581,11 @@ SmallVector SliceFromCollapseHelper::getExtractSliceParams( // have proven that these are not sliced. In this case we just take // the full extent of each dimension in the reassociation list. if (linearizedDimensions[it.index()]) { - llvm::append_range( - offsetsSizesAndStrides, - llvm::map_range(it.value(), [&](int64_t idx) -> Range { - return {zeroAttr, collapseShapeInputShape[idx], oneAttr}; - })); + llvm::append_range(offsetsSizesAndStrides, + llvm::map_range(it.value(), [&](int64_t idx) -> Range { + return {zeroAttr, collapseShapeInputShape[idx], + oneAttr}; + })); continue; } diff --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir index 51350e5bc8498..6979770154bab 100644 --- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir +++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir @@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> { // ----- // CHECK-LABEL: func.func @unpack_dynamic -// CHECK-NOT: tensor.collapse -// CHECK: linalg.unpack +// CHECK: tensor.collapse +// CHECK-NOT: linalg.unpack func.func @unpack_dynamic(%arg0: tensor) -> tensor { %c32 = arith.constant 32 : index %c0 = arith.constant 0 : index diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index cdcd7f305d2d9..3eaf824b99115 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1094,7 +1094,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3 // ----- -func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor, %arg1: index, %arg2: index) +func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor, %arg1: index, %arg2: index) -> tensor { %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor into tensor @@ -1102,12 +1102,28 @@ func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor, %arg1: ind : tensor into tensor return %1 : tensor } -// CHECK-LABEL: @fold_expand_of_collapse_dynamic +// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape // CHECK-NOT: tensor.{{.*}}_shape // ----- -func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor, %arg1: index, %arg2: index, %arg3: index) +func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor, %arg1: index, %arg2: index) + -> tensor { + %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] + : tensor into tensor + %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2] + : tensor into tensor + return %1 : tensor +} +// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape +// CHECK-NOT: tensor.expand_shape +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]] +// CHECK-SAME: : tensor into tensor +// CHECK-NEXT: return %[[COLLAPSE]] + +// ----- + +func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor, %arg1: index, %arg2: index, %arg3: index) -> tensor { %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor into tensor @@ -1115,7 +1131,22 @@ func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor, %arg1: : tensor into tensor return %1 : tensor } -// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic +// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic +// CHECK: tensor.collapse_shape +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape +// CHECK: return %[[EXPAND]] + +// ----- + +func.func @no_fold_expand_of_collapse_adjacent_dynamic(%arg0 : tensor, %arg1: index, %arg2: index) + -> tensor { + %0 = tensor.collapse_shape %arg0 [[0, 1, 2]] + : tensor into tensor + %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%arg1, %arg2] + : tensor into tensor + return %1 : tensor +} +// CHECK-LABEL: @no_fold_expand_of_collapse_adjacent_dynamic // CHECK: tensor.collapse_shape // CHECK: %[[EXPAND:.+]] = tensor.expand_shape // CHECK: return %[[EXPAND]] diff --git a/mlir/unittests/Dialect/Utils/CMakeLists.txt b/mlir/unittests/Dialect/Utils/CMakeLists.txt index 61b9cdcb3b8f3..e921c8bcfb4e5 100644 --- a/mlir/unittests/Dialect/Utils/CMakeLists.txt +++ b/mlir/unittests/Dialect/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_unittest(MLIRDialectUtilsTests StructuredOpsUtilsTest.cpp + ReshapeOpsUtilsTest.cpp IndexingUtilsTest.cpp ) mlir_target_link_libraries(MLIRDialectUtilsTests diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp new file mode 100644 index 0000000000000..db1a87a4de2d5 --- /dev/null +++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp @@ -0,0 +1,203 @@ +//===- ReshapeOpsUtilsTest.cpp - ReshapeOpsUtils unit tests ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "llvm/ADT/STLExtras.h" +#include "gtest/gtest.h" +#include + +using namespace mlir; + +/// Helper to make constructing +/// `std::optional>` more readable. +static std::optional> +makeOptionalIndices(std::initializer_list list) { + return std::optional>(list); +} + +TEST(ReassociationIndicesForCollapse, ScalarTest) { + EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}), + makeOptionalIndices({{0}})); + EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}), + makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}), + makeOptionalIndices({{0}})); + EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, + ShapedType::kDynamic, 1, + ShapedType::kDynamic}, + {}), + makeOptionalIndices({{0, 1, 2, 3, 4}})); +} + +TEST(ReassociationIndicesForCollapse, ScalarTestFailure) { + EXPECT_EQ(getReassociationIndicesForCollapse({}, {}), std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse({}, {1}), std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse({2}, {}), std::nullopt); + EXPECT_EQ( + getReassociationIndicesForCollapse({1, 2, ShapedType::kDynamic, 1}, {}), + std::nullopt); +} + +TEST(ReassociationIndicesForCollapse, StaticTest) { + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {200}), + makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {10, 600}), + makeOptionalIndices({{0}, {1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 30}), + makeOptionalIndices({{0, 1}, {2}})); +} + +TEST(ReassociationIndicesForCollapse, StaticTestFailure) { + // No-op reassociation + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10, 20}), + std::nullopt); + // Invalid static reassociations + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10}), std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 300}), + std::nullopt); + // Non-collapsing (expanding) reassociation + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {1, 10, 20, 30}), + std::nullopt); +} + +TEST(ReassociationIndicesForCollapse, StaticTestUnitDims) { + EXPECT_EQ(getReassociationIndicesForCollapse({10, 1}, {10}), + makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse({1, 20, 30}, {600}), + makeOptionalIndices({{0, 1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1}), + makeOptionalIndices({{0, 1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1, 1}, {1, 1, 1}), + makeOptionalIndices({{0}, {1}, {2, 3}})); +} + +TEST(ReassociationIndicesForCollapse, DynamicTest) { + EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1}, + {ShapedType::kDynamic}), + makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1, 1}, + {ShapedType::kDynamic}), + makeOptionalIndices({{0, 1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {1, ShapedType::kDynamic, 1, ShapedType::kDynamic, 1}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + makeOptionalIndices({{0, 1}, {2, 3, 4}})); + EXPECT_EQ( + getReassociationIndicesForCollapse( + {ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic}), + makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {1, ShapedType::kDynamic, ShapedType::kDynamic}, + {1, ShapedType::kDynamic}), + makeOptionalIndices({{0}, {1, 2}})); + + EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10}, + {ShapedType::kDynamic}), + makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {1, ShapedType::kDynamic, ShapedType::kDynamic}, + {ShapedType::kDynamic}), + makeOptionalIndices({{0, 1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic}, + {ShapedType::kDynamic}), + makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 1, 2, ShapedType::kDynamic, 10}, + {ShapedType::kDynamic, 10}), + makeOptionalIndices({{0, 1, 2, 3}, {4}})); + EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20}, + {ShapedType::kDynamic, 20}), + makeOptionalIndices({{0, 1}, {2}})); + EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic, 20}, + {ShapedType::kDynamic, 20}), + makeOptionalIndices({{0, 1}, {2}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 3, 2, 5, 2}, {ShapedType::kDynamic, 20}), + makeOptionalIndices({{0, 1}, {2, 3, 4}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {10, ShapedType::kDynamic, 20, ShapedType::kDynamic, 1}, + {ShapedType::kDynamic, 20, ShapedType::kDynamic}), + makeOptionalIndices({{0, 1}, {2}, {3, 4}})); + EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1}, + {ShapedType::kDynamic}), + makeOptionalIndices({{0, 1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, ShapedType::kDynamic, 1}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + makeOptionalIndices({{0}, {1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {1, ShapedType::kDynamic, ShapedType::kDynamic}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + makeOptionalIndices({{0, 1}, {2}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 1, ShapedType::kDynamic}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + makeOptionalIndices({{0}, {1, 2}})); +} + +TEST(ReassociationIndicesForCollapse, DynamicTestFailure) { + EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20}, + {ShapedType::kDynamic, 10}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 10, ShapedType::kDynamic}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {20, ShapedType::kDynamic, 10, ShapedType::kDynamic}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 5, 3, 2, 2}, {ShapedType::kDynamic, 20}), + std::nullopt); + EXPECT_EQ( + getReassociationIndicesForCollapse( + {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, ShapedType::kDynamic, 10, 1, + ShapedType::kDynamic}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic}, + {ShapedType::kDynamic, 10, ShapedType::kDynamic}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic}, + {ShapedType::kDynamic, 2, 2, ShapedType::kDynamic}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 3, 4, 3, ShapedType::kDynamic}, + {ShapedType::kDynamic, 12, ShapedType::kDynamic}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 8, 4, 2, 16, ShapedType::kDynamic}, + {ShapedType::kDynamic, 32, ShapedType::kDynamic}), + std::nullopt); + + //===----------------------------------------------------------------------===// + // TODO: Reassociation for the following examples can be computed, but isn't + // supported by `getReassociationIndicesForCollapse`. + //===----------------------------------------------------------------------===// + + // TODO: Fails because there's no backtracking when some source dimensions + // remain unmatched at either edge. + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 10, ShapedType::kDynamic, 10}, + {ShapedType::kDynamic, 10}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 2, 2}, + {1, ShapedType::kDynamic, 2}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse({2, 2, ShapedType::kDynamic, 1}, + {2, ShapedType::kDynamic}), + std::nullopt); +}