-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][vector] Canonicalize/fold 'order preserving' transposes #135841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: James Newling (newling) ChangesHandles special case where transpose doesn't permute any non-1 dimensions (and so is effectively a shape_cast) and is adjacent to a shape_cast that it can fold into. For example
can be folded into an adjacent shape_cast. An alternative to this PR would be to canonicalize such transposes to shape_casts directly, but I think it'll be difficult getting consensus that shape_cast is 'more canonical' than transpose, so this PR compromises with the less opinionated claim that
The pattern Full diff: https://github.com/llvm/llvm-project/pull/135841.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..5da0ef0af032f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5621,6 +5621,29 @@ LogicalResult ShapeCastOp::verify() {
return success();
}
+namespace {
+
+/// Return true if `transpose` does not permute a pair of dimensions that are
+/// both not of size 1. By `order preserving` we mean that the flattened
+/// versions of the input and output vectors are (numerically) identical.
+/// In other words `transpose` is effectively a shape cast.
+bool isOrderPreserving(TransposeOp transpose) {
+ ArrayRef<int64_t> permutation = transpose.getPermutation();
+ ArrayRef<int64_t> inShape = transpose.getSourceVectorType().getShape();
+ int64_t current = 0;
+ for (auto p : permutation) {
+ if (inShape[p] != 1) {
+ if (p < current) {
+ return false;
+ }
+ current = p;
+ }
+ }
+ return true;
+}
+
+} // namespace
+
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// No-op shape cast.
@@ -5629,13 +5652,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
VectorType resultType = getType();
- // Canceling shape casts.
- if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-
- // Only allows valid transitive folding (expand/collapse dimensions).
- VectorType srcType = otherOp.getSource().getType();
+ // shape_cast(something(x)) -> x, or
+ // -> shape_cast(x).
+ //
+ // Confirms that a new shape_cast will have valid semantics (expands OR
+ // collapses dimensions).
+ auto maybeFold = [&](TypedValue<VectorType> source) -> OpFoldResult {
+ VectorType srcType = source.getType();
if (resultType == srcType)
- return otherOp.getSource();
+ return source;
if (srcType.getRank() < resultType.getRank()) {
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
return {};
@@ -5645,8 +5670,25 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
} else {
return {};
}
- setOperand(otherOp.getSource());
+ setOperand(source);
return getResult();
+ };
+
+ // Canceling shape casts.
+ if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
+ TypedValue<VectorType> source = otherOp.getSource();
+ return maybeFold(source);
+ }
+
+ // shape_cast(transpose(x)) -> shape_cast(x)
+ if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
+ if (transpose.getType().isScalable())
+ return {};
+ if (isOrderPreserving(transpose)) {
+ TypedValue<VectorType> source = transpose.getVector();
+ return maybeFold(source);
+ }
+ return {};
}
// Cancelling broadcast and shape cast ops.
@@ -5675,7 +5717,7 @@ namespace {
/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
///
-/// vector<4x1x1xi1> --> vector<4x1>
+/// vector<4x1x1xi1> --> vector<4x1xi1>
///
static VectorType trimTrailingOneDims(VectorType oldType) {
ArrayRef<int64_t> oldShape = oldType.getShape();
@@ -6161,12 +6203,40 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds transpose(shape_cast) into a new shape_cast.
+class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto shapeCastOp =
+ transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
+ if (!shapeCastOp)
+ return failure();
+ if (!isOrderPreserving(transposeOp))
+ return failure();
+ if (transposeOp.getType().isScalable())
+ return failure();
+
+ VectorType resultType = transposeOp.getType();
+
+ // We don't need to check isValidShapeCast at this point, because it is
+ // guaranteed that merging the transpose into the the shape_cast is a valid
+ // shape_cast, because the transpose just inserts/removes ones.
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
+ shapeCastOp.getSource());
+ return success();
+ }
+};
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
- TransposeFolder, FoldTransposeSplat>(context);
+ FoldTransposeShapeCast, TransposeFolder, FoldTransposeSplat>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..10144cb9034e4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3295,3 +3295,67 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
return %res : vector<4x1xi32>
}
+
+// -----
+
+// In this test, the permutation maps the non-one dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// CHECK-LABEL: @transpose_shape_cast
+// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x4xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @transpose_shape_cast(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+ %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+ : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
+ %1 = vector.shape_cast %0 : vector<4x1x1x1x4xi8> to vector<4x4xi8>
+ return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// In this test, the mapping of non-one indices (1 and 2) is as follows:
+// 1 -> 2
+// 2 -> 1
+// As this is not increasing (2 > 1), this transpose is not order
+// preserving and cannot be treated as a shape_cast.
+// CHECK-LABEL: @negative_transpose_shape_cast
+// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG]]
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSPOSE]]
+// CHECK: return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @negative_transpose_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+ %0 = vector.transpose %arg, [0, 2, 1, 3]
+ : vector<1x4x4x1xi8> to vector<1x4x4x1xi8>
+ %1 = vector.shape_cast %0 : vector<1x4x4x1xi8> to vector<4x4xi8>
+ return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @shape_cast_transpose
+// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<2x3x1x1xi8> to vector<6x1x1xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<6x1x1xi8>
+func.func @shape_cast_transpose(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+ %0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8>
+ %1 = vector.transpose %0, [0, 2, 1]
+ : vector<6x1x1xi8> to vector<6x1x1xi8>
+ return %1 : vector<6x1x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_shape_cast_transpose
+// CHECK-SAME: %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
+// CHECK: return %[[TRANSPOSE]] : vector<2x3xi8>
+func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8> {
+ %0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8>
+ %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
+ return %1 : vector<2x3xi8>
+}
|
b21a4a6
to
f4ae206
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, makes sense!
Quite possibly there’s no need to be so strict with "scalable" vectors (thanks for checking!) — let me experiment and confirm, probably tomorrow.
Btw:
- "not of size 1" vs "non-one dimensions" vs "non-one indices"
We tend to say "non-unit dims" 😉 I hope I'm not being overly pedantic by suggesting using that wording in this PR as well. (Also, ESL here — so feel free to suggest something else if you prefer!)
Thanks!
This is useful info, I don't want to create unnecessary fragmentation. Will update this PR. |
I was able to run all end-to-end tests targeting scalable vectors using your patch (after commenting out both instances of if ( Could you add some dedicated tests for scalable vectors? Also, you’ll want to disable these transformations when encountering scalable unit dims, i.e., Lastly, would you mind adding some labels (or block comments) to "vector-transpose.mlir" to separate the tests for Thanks again! |
Thanks @banach-space I'll address these comments, but just to mention that there still is test that fails (hangs) for me, it is The problem is that the pattern ConvertIllegalShapeCastOpsToTransposes converts
into
but the canonicalizer added here does the exact reverse! So we get an infinite loop. The place where this transformation happens:
I've tried to dig a bit deeper, but don't see where shape_cast gets further lowered (I was wondering if |
In that case, I think the right approach is to first remove Sorry for the back-and-forth, but let’s keep things as-is for now. Would you mind leaving a TODO for me to revisit and relax these constraints? Ideally, reference Also, would you mind replacing |
I think this approach makes sense. I've made these updates (and am about to push them). I realized that only one of the 2 patterns needs to have scalable vectors disabled, so I've enabled the other one (i.e. transpose(shape_cast) -> shape_cast can handle scalable vectors fine, I didn't need to fail in this case too). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've made one final request - otherwise, this LGTM. Thanks!
(Approving as-is, but please address that final comment.)
By the way, I’ll be away next week. Absolutely no need to wait for me to review or approve anything - just letting you know so you don’t feel blocked by a lack of reviews. And I have a feeling you'll be sending more patches soon anyway :)
ArrayRef<int64_t> inShape = transpose.getSourceVectorType().getShape(); | ||
int64_t current = 0; | ||
for (auto p : permutation) { | ||
if (inShape[p] != 1) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic should also check the scalable flags - otherwise [1]
will be treated as 1
. Basically, we should disable it for things like this:
%1 = vector.transpose %0, [0, 2, 1]
: vector<6x1x[1]xi8> to vector<6x[1]x1xi8>
We are very unlikely to ever generate this, but it's better to be safe than sorry.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yes, sorry you mentioned that earlier too. Will fix this.
Ok thanks for letting me know. And thank you so much for all the reviewing you've been doing on my PRs :) |
ec66724
to
11fd6c3
Compare
@@ -5575,6 +5575,34 @@ LogicalResult ShapeCastOp::verify() { | |||
return success(); | |||
} | |||
|
|||
namespace { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably a case for static
: https://llvm.org/docs/CodingStandards.html#restrict-visibility
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I intend to change this in a follow-up PR.
// ConvertIllegalShapeCastOpsToTransposes only matches on scalable | ||
// vectors, so by disabling this folder for scalable vectors the | ||
// cycle is avoided. | ||
// TODO: Check if ConvertIllegalShapeCastOpsToTransposes is |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you know why we generate illegal shape cast ops in first place? It sounds like something that shouldn't happen...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They're illegal according to an arm specific lowering target, which I'm not familiar with
struct ConvertIllegalShapeCastOpsToTransposes |
I think @banach-space suspects that actually they're not illegal and will investigate the removal of this constraint here (and the pattern ConvertIllegalShapeCastOpsToTransposes).
…135841) Handles special case where transpose doesn't permute any non-1 dimensions (and so is effectively a shape_cast) and is adjacent to a shape_cast that it can fold into. For example ``` %1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x1x1x6xf32> to vector<1x4x6x1xf32> ``` can be folded into an adjacent shape_cast. An alternative to this PR would be to canonicalize such transposes to shape_casts directly, but I think it'll be difficult getting consensus that shape_cast is 'more canonical' than transpose, so this PR compromises with the less opinionated claim that 1) shape_cast is more canonical than shape_cast(transpose) 2) shape_cast is more canonical than transpose(shape_cast) The pattern `ConvertIllegalShapeCastOpsToTransposes` that is specific to transposes with scalable dimensions reverses the canonicalization added here, so I've I've disabled this canonicalization for scalable vectors
…135841) Handles special case where transpose doesn't permute any non-1 dimensions (and so is effectively a shape_cast) and is adjacent to a shape_cast that it can fold into. For example ``` %1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x1x1x6xf32> to vector<1x4x6x1xf32> ``` can be folded into an adjacent shape_cast. An alternative to this PR would be to canonicalize such transposes to shape_casts directly, but I think it'll be difficult getting consensus that shape_cast is 'more canonical' than transpose, so this PR compromises with the less opinionated claim that 1) shape_cast is more canonical than shape_cast(transpose) 2) shape_cast is more canonical than transpose(shape_cast) The pattern `ConvertIllegalShapeCastOpsToTransposes` that is specific to transposes with scalable dimensions reverses the canonicalization added here, so I've I've disabled this canonicalization for scalable vectors
…135841) Handles special case where transpose doesn't permute any non-1 dimensions (and so is effectively a shape_cast) and is adjacent to a shape_cast that it can fold into. For example ``` %1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x1x1x6xf32> to vector<1x4x6x1xf32> ``` can be folded into an adjacent shape_cast. An alternative to this PR would be to canonicalize such transposes to shape_casts directly, but I think it'll be difficult getting consensus that shape_cast is 'more canonical' than transpose, so this PR compromises with the less opinionated claim that 1) shape_cast is more canonical than shape_cast(transpose) 2) shape_cast is more canonical than transpose(shape_cast) The pattern `ConvertIllegalShapeCastOpsToTransposes` that is specific to transposes with scalable dimensions reverses the canonicalization added here, so I've I've disabled this canonicalization for scalable vectors
…135841) Handles special case where transpose doesn't permute any non-1 dimensions (and so is effectively a shape_cast) and is adjacent to a shape_cast that it can fold into. For example ``` %1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x1x1x6xf32> to vector<1x4x6x1xf32> ``` can be folded into an adjacent shape_cast. An alternative to this PR would be to canonicalize such transposes to shape_casts directly, but I think it'll be difficult getting consensus that shape_cast is 'more canonical' than transpose, so this PR compromises with the less opinionated claim that 1) shape_cast is more canonical than shape_cast(transpose) 2) shape_cast is more canonical than transpose(shape_cast) The pattern `ConvertIllegalShapeCastOpsToTransposes` that is specific to transposes with scalable dimensions reverses the canonicalization added here, so I've I've disabled this canonicalization for scalable vectors
Handles special case where transpose doesn't permute any non-1 dimensions (and so is effectively a shape_cast) and is adjacent to a shape_cast that it can fold into. For example
can be folded into an adjacent shape_cast. An alternative to this PR would be to canonicalize such transposes to shape_casts directly, but I think it'll be difficult getting consensus that shape_cast is 'more canonical' than transpose, so this PR compromises with the less opinionated claim that
The pattern
ConvertIllegalShapeCastOpsToTransposes
that is specific to transposes with scalable dimensions reverses the canonicalization added here, so I've I've disabled this canonicalization for scalable vectors