[mlir][tensor] Loosen restrictions on folding dynamic reshapes #137963
Conversation
The main idea behind the change is to allow expand-of-collapse folds for reshapes like `?x?xk` -> `?` (k>1). The rationale here is that the expand op must have a coherent index/affine expression specified in its `output_shape` argument (see example below), and if it doesn't, the IR has already been invalidated at an earlier stage:

```mlir
%c32 = arith.constant 32 : index
%div = arith.divsi %<some_index>, %c32 : index
%collapsed = tensor.collapse_shape %41#1 [[0], [1, 2], [3, 4]] : tensor<9x?x32x?x32xf32> into tensor<9x?x?xf32>
%affine = affine.apply affine_map<()[s0] -> (s0 * 32)> ()[%div]
%expanded = tensor.expand_shape %collapsed [[0], [1, 2], [3]] output_shape [9, %div, 32, %affine] : tensor<9x?x?xf32> into tensor<9x?x32x?xf32>
```

On the above assumption, adjust the routine in `getReassociationIndicesForCollapse()` to allow dynamic reshapes beyond just `?x..?x1x1x..x1` -> `?`.

Moreover, the reassociation util was refactored to clearly distinguish between dynamic and static subshapes. A few known caveats were noted as a comment; it doesn't seem possible to fold all qualifying dynamic shape patterns in a deterministic way without looking into affine expressions simultaneously. That would be difficult to maintain in a single general utility. Other implementation ideas/larger refactoring could include:

- abandoning the util usage in the `ComposeExpandOfCollapseOp` pattern, employing similar logic to `ComposeCollapseOfExpandOp`;
- providing dialect-specific implementations for Linalg/Tensor.

Signed-off-by: Artem Gindinson <[email protected]>
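To make the newly-enabled fold concrete, here is a sketch mirroring the `@fold_expand_of_collapse_mixed_target_subshape` test added in this PR (`%arg0`, `%arg1`, `%arg2` are the function arguments from that test):

```mlir
// Before: collapse ?x4x?x2 -> ?x?, then re-expand ?x? -> ?x4x?.
%0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
    : tensor<?x4x?x2xf32> into tensor<?x?xf32>
%1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
    : tensor<?x?xf32> into tensor<?x4x?xf32>

// After the fold: the pair cancels into a single collapse of the trailing dims.
%folded = tensor.collapse_shape %arg0 [[0], [1], [2, 3]]
    : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
```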
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Artem Gindinson (AGindinson)

Full diff: https://github.com/llvm/llvm-project/pull/137963.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index ed40a080441bc..694783849198a 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -31,59 +31,70 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> targetShape) {
- if (sourceShape.size() <= targetShape.size())
+ unsigned numSourceDims = sourceShape.size(),
+ numTargetDims = targetShape.size();
+ if (numSourceDims <= numTargetDims)
return std::nullopt;
- unsigned sourceDim = 0;
- SmallVector<ReassociationIndices> reassociationMap;
- reassociationMap.reserve(targetShape.size());
-
- ReassociationIndices currIndices;
- int64_t prodOfCollapsedDims = 1;
- while (sourceDim < sourceShape.size()) {
- unsigned targetDim = reassociationMap.size();
- // If we have mapped all the target dimensions stop and handle the remaining
- // tail of size-1 dimensions explicitly.
- if (targetDim == targetShape.size())
- break;
+ SmallVector<ReassociationIndices, 4> reassociationMap;
+ reassociationMap.reserve(numTargetDims);
+ unsigned sourceDim = 0, targetDim = 0;
+ for (; targetDim < numTargetDims; ++targetDim) {
int64_t currTargetShape = targetShape[targetDim];
- while (sourceDim < (sourceShape.size() - 1) &&
- sourceShape[sourceDim] != ShapedType::kDynamic &&
- prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+ ReassociationIndices currIndices;
+ // 1. Target dimension is dynamic. Source shape should contain at least
+ // one dynamic dimension.
+ if (currTargetShape == ShapedType::kDynamic) {
+ // FIXME: We stop the search with the first dynamic dimension, while in
+ // fact, we can have a valid pattern like 2x?x?x4x8 -> ?x4x8. It becomes
+ // indeterministic altogether when we have neighboring dynamic dimensions
+ // in the target shape. Most of these patterns will be safely rejected,
+ // however we might achieve more correct folds by taking affine
+ // expressions into account, if these can be passed on by the call sites.
+ bool foundDynamic = false;
+ while (sourceDim < numSourceDims) {
+ currIndices.push_back(sourceDim);
+ if (sourceShape[sourceDim++] == ShapedType::kDynamic) {
+ foundDynamic = true;
+ break;
+ }
+ }
+ if (!foundDynamic)
+ return std::nullopt;
+
+ reassociationMap.push_back(currIndices);
+ continue;
+ }
+ // 2. Target dimension is static. The product of dimensions of the expanded
+ // shape should match the collapsed dimension shape.
+ int64_t prodOfCollapsedDims = 1;
+ bool reachedTargetDimSize = false;
+ while (sourceDim < numSourceDims) {
+ // Source shape cannot be dynamic if the target dim is static.
+ if (sourceShape[sourceDim] == ShapedType::kDynamic)
+ return std::nullopt;
prodOfCollapsedDims *= sourceShape[sourceDim];
- currIndices.push_back(sourceDim++);
+ if (prodOfCollapsedDims > currTargetShape)
+ break;
+ else if (prodOfCollapsedDims == currTargetShape) {
+ currIndices.push_back(sourceDim++);
+ reachedTargetDimSize = true;
+ break;
+ } else // prodOfCollapsedDims < currTargetShape
+ currIndices.push_back(sourceDim++);
}
-
- // If the current expanded dimension is dynamic, then the collapsed
- // dimensions should also be dynamic and product of all previous unprocessed
- // dimensions of the expanded shape should be 1.
- if (sourceShape[sourceDim] == ShapedType::kDynamic &&
- (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
+ if (!reachedTargetDimSize)
return std::nullopt;
-
- // If the collapsed dim is dynamic, the current expanded dim should also
- // be dynamic.
- if (currTargetShape == ShapedType::kDynamic &&
- sourceShape[sourceDim] != ShapedType::kDynamic)
- return std::nullopt;
-
- // For static shapes, if the product of dimensions of the expanded shape
- // should match the collapsed dimension shape.
- if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
- return std::nullopt;
-
- currIndices.push_back(sourceDim++);
- reassociationMap.emplace_back(ReassociationIndices{});
- std::swap(reassociationMap.back(), currIndices);
- prodOfCollapsedDims = 1;
+ reassociationMap.push_back(currIndices);
}
- // All the dimensions in the target must have been processed.
- if (reassociationMap.size() != targetShape.size())
- return std::nullopt;
- // Process any remaining entries in the source shape. They all need to be
- // 1 or dynamic.
- for (; sourceDim < sourceShape.size(); sourceDim++) {
- if (sourceShape[sourceDim] != ShapedType::kDynamic &&
+ // Now that we've mapped all the target dimensions, process any remaining
+ // entries in the source shape explicitly. Either the last target dimension
+ // is dynamic, or all remaining source entries need to be 1 or dynamic. Same
+ // applies when target shape is empty (can be the case for subshape
+ // reassociations).
+ for (; sourceDim < numSourceDims; sourceDim++) {
+ if ((targetShape.empty() || targetShape.back() != ShapedType::kDynamic) &&
+ sourceShape[sourceDim] != ShapedType::kDynamic &&
sourceShape[sourceDim] != 1)
return std::nullopt;
// The map is empty when the target type is a scalar.
diff --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
index 51350e5bc8498..6979770154bab 100644
--- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
// -----
// CHECK-LABEL: func.func @unpack_dynamic
-// CHECK-NOT: tensor.collapse
-// CHECK: linalg.unpack
+// CHECK: tensor.collapse
+// CHECK-NOT: linalg.unpack
func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 85bf6fba52aa4..443f931745557 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1068,7 +1068,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3
// -----
-func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
-> tensor<?x4x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
: tensor<?x4x?xf32> into tensor<?x?xf32>
@@ -1076,12 +1076,28 @@ func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: ind
: tensor<?x?xf32> into tensor<?x4x?xf32>
return %1 : tensor<?x4x?xf32>
}
-// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
// CHECK-NOT: tensor.{{.*}}_shape
// -----
-func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
+ -> tensor<?x4x?xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
+ : tensor<?x4x?x2xf32> into tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+ : tensor<?x?xf32> into tensor<?x4x?xf32>
+ return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape
+// CHECK-NOT: tensor.expand_shape
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]]
+// CHECK-SAME: : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
+// CHECK-NEXT: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
-> tensor<?x?x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
: tensor<?x?x?xf32> into tensor<?x?xf32>
@@ -1089,7 +1105,7 @@ func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1:
: tensor<?x?xf32> into tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
-// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic
// CHECK: tensor.collapse_shape
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
// CHECK: return %[[EXPAND]]
```cpp
    ReassociationIndices currIndices;
    // 1. Target dimension is dynamic. Source shape should contain at least
    // one dynamic dimension.
    if (currTargetShape == ShapedType::kDynamic) {
```
There may be a pre-existing issue here: we should bail out when collapsing shapes that contain two or more adjacent dynamic dimensions, since the reassociation becomes indeterminate (as mentioned in the code comment below).

For example, the reshape `?x?x?` -> `?x?` could validly map to either:
- `[[0], [1, 2]]`, or
- `[[0, 1], [2]]`

Both are valid, so the reassociation cannot be determined. Here's a gist with MLIR test cases.

As far as I can tell, if the output contains no adjacent dynamic dims, the reassociation should be uniquely inferable. In that case, it might be worth considering a backtracking algorithm to fully generalize `getReassociationIndicesForCollapse`.
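To make the ambiguity concrete, a minimal sketch (with a hypothetical `%arg0`): both collapses below produce `tensor<?x?xf32>` from `tensor<?x?x?xf32>`, so no unique reassociation can be inferred from the shapes alone:

```mlir
// Reassociation [[0], [1, 2]]: keep dim 0, merge dims 1 and 2.
%a = tensor.collapse_shape %arg0 [[0], [1, 2]]
    : tensor<?x?x?xf32> into tensor<?x?xf32>
// Reassociation [[0, 1], [2]]: merge dims 0 and 1, keep dim 2.
%b = tensor.collapse_shape %arg0 [[0, 1], [2]]
    : tensor<?x?x?xf32> into tensor<?x?xf32>
```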
All good points, thanks!

In its current form, the algorithm safely bails out on `?x?x?` -> `?x?`. I'd actually tested a `no_fold_expand_of_collapse_adjacent_dynamic` example manually but forgot to add it as a negative test - worth doing.

I believe a lot of real-world cases would actually be inferable based on `output_shape`-contained affine expressions, but as mentioned, this would require a larger refactor. I'll try and reword that `FIXME` below to clearly distinguish between what we can/cannot support going forward.

Backtracking of some sort should be doable - I'll see if I can come up with something. The previous version couldn't provide it either in the cases that I had tried out - would you prefer to see a rewrite in scope of the current PR, or have a TODO + follow-up?
> In its current form, the algorithm safely bails out on `?x?x?` -> `?x?`.

Maybe I missed something, but I thought it was folding in this case.

> would you prefer to see a rewrite in scope of the current PR

If this also fixes `?x?x?` -> `?x?`, it probably makes sense to have a follow-up for a generalized solution.
Re-checked the adjacent dynamic case, and I've clearly got lost in my own cases - it does fold right now, so I'll look for a way to disable that in the current form of the algorithm. Sorry for the confusion.

Backtracking for mixed shape folds might fit in with that, so I'll try to incorporate it directly.
While looking into a more generalized algorithm, I've discovered some examples where the assumption doesn't apply:

> if the output contains no adjacent dynamic dims the reassociation should be uniquely inferable

`?x2x8x3x2x?` into `?x48x?` would be a case where relieving the legality checks for mixed subshapes per my proposal makes the reassociation non-deterministic. Here, the original algorithm would attempt `[[0], [1, 2, 3], ...]` and then just freak out at `2x?` -> `?`. A proper implementation of my version should instead determine that there's more than one valid reassociation for the static slice and early-exit based on that, as the sketch below illustrates.
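For concreteness, the two competing reassociations (with a hypothetical `%src`): the static `48` can be carved out of either `2x8x3` or `8x3x2`, and the dynamic groups absorb the remaining dims either way:

```mlir
// Option A: 48 = 2*8*3, reassociation [[0], [1, 2, 3], [4, 5]].
%a = tensor.collapse_shape %src [[0], [1, 2, 3], [4, 5]]
    : tensor<?x2x8x3x2x?xf32> into tensor<?x48x?xf32>
// Option B: 48 = 8*3*2, reassociation [[0, 1], [2, 3, 4], [5]].
%b = tensor.collapse_shape %src [[0, 1], [2, 3, 4], [5]]
    : tensor<?x2x8x3x2x?xf32> into tensor<?x48x?xf32>
```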
Then, on a valid example like `?x5x8x3x2x?` into `?x48x?`, I would have to retry the reassociation after the initial failure to map `5x8x3` onto `48`. And for an invalid reshape like `?x8x3x1x1x1x1x5x2x?` into `?x48x?`, that would take "a while", because any initial analysis of denominators wouldn't detect the sequential pattern.

So even if I manage to early-exit for a bulk of such cases, at the end of the day we still end up with `O(n^2)` complexity* for a bunch of valid reshapes. Although this isn't necessarily a problem, because most of the subshape ranks we're likely to deal with are in single digits.

*omitting stuff like `?x2x4x?x3x5x?x...x?` into `?x8x?x15x...`, because it could be divided-and-conquered by mapping the purely static slices between the source & target.

The bottom line: a general, abstract algorithm is possible, but I'm no longer convinced it's actually worth it for tensor reshapes. It might be better if I take the other approach:

> abandoning the util usage in the `ComposeExpandOfCollapseOp` pattern, employing similar logic to `ComposeCollapseOfExpandOp`