Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[mlir][vector] Canonicalize/fold 'order preserving' transposes #135841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 1, 2025

Conversation

newling
Copy link
Contributor

@newling newling commented Apr 15, 2025

Handles special case where transpose doesn't permute any non-1 dimensions (and so is effectively a shape_cast) and is adjacent to a shape_cast that it can fold into. For example

%1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x1x1x6xf32> to vector<1x4x6x1xf32>

can be folded into an adjacent shape_cast. An alternative to this PR would be to canonicalize such transposes to shape_casts directly, but I think it'll be difficult getting consensus that shape_cast is 'more canonical' than transpose, so this PR compromises with the less opinionated claim that

  1. shape_cast is more canonical than shape_cast(transpose)
  2. shape_cast is more canonical than transpose(shape_cast)

The pattern ConvertIllegalShapeCastOpsToTransposes that is specific to transposes with scalable dimensions reverses the canonicalization added here, so I've I've disabled this canonicalization for scalable vectors

@llvmbot
Copy link
Member

llvmbot commented Apr 16, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Changes

Handles special case where transpose doesn't permute any non-1 dimensions (and so is effectively a shape_cast) and is adjacent to a shape_cast that it can fold into. For example

%1 = vector.transpose %0, [1, 0, 3, 2] : vector&lt;4x1x1x6xf32&gt; to vector&lt;1x4x6x1xf32&gt;

can be folded into an adjacent shape_cast. An alternative to this PR would be to canonicalize such transposes to shape_casts directly, but I think it'll be difficult getting consensus that shape_cast is 'more canonical' than transpose, so this PR compromises with the less opinionated claim that

  1. shape_cast is more canonical than shape_cast(transpose)
  2. shape_cast is more canonical than transpose(shape_cast)

The pattern ConvertIllegalShapeCastOpsToTransposes that is specific to transposes with scalable dimensions reverses the canonicalization added here, so I've I've disabled this canonicalization for scalable vectors


Full diff: https://github.com/llvm/llvm-project/pull/135841.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+79-9)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+64)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..5da0ef0af032f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5621,6 +5621,29 @@ LogicalResult ShapeCastOp::verify() {
   return success();
 }
 
+namespace {
+
+/// Return true if `transpose` does not permute a pair of dimensions that are
+/// both not of size 1. By `order preserving` we mean that the flattened
+/// versions of the input and output vectors are (numerically) identical.
+/// In other words `transpose` is effectively a shape cast.
+bool isOrderPreserving(TransposeOp transpose) {
+  ArrayRef<int64_t> permutation = transpose.getPermutation();
+  ArrayRef<int64_t> inShape = transpose.getSourceVectorType().getShape();
+  int64_t current = 0;
+  for (auto p : permutation) {
+    if (inShape[p] != 1) {
+      if (p < current) {
+        return false;
+      }
+      current = p;
+    }
+  }
+  return true;
+}
+
+} // namespace
+
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // No-op shape cast.
@@ -5629,13 +5652,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
 
-  // Canceling shape casts.
-  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-
-    // Only allows valid transitive folding (expand/collapse dimensions).
-    VectorType srcType = otherOp.getSource().getType();
+  // shape_cast(something(x)) -> x, or
+  //                          -> shape_cast(x).
+  //
+  // Confirms that a new shape_cast will have valid semantics (expands OR
+  // collapses dimensions).
+  auto maybeFold = [&](TypedValue<VectorType> source) -> OpFoldResult {
+    VectorType srcType = source.getType();
     if (resultType == srcType)
-      return otherOp.getSource();
+      return source;
     if (srcType.getRank() < resultType.getRank()) {
       if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
         return {};
@@ -5645,8 +5670,25 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
     } else {
       return {};
     }
-    setOperand(otherOp.getSource());
+    setOperand(source);
     return getResult();
+  };
+
+  // Canceling shape casts.
+  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
+    TypedValue<VectorType> source = otherOp.getSource();
+    return maybeFold(source);
+  }
+
+  // shape_cast(transpose(x)) -> shape_cast(x)
+  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
+    if (transpose.getType().isScalable())
+      return {};
+    if (isOrderPreserving(transpose)) {
+      TypedValue<VectorType> source = transpose.getVector();
+      return maybeFold(source);
+    }
+    return {};
   }
 
   // Cancelling broadcast and shape cast ops.
@@ -5675,7 +5717,7 @@ namespace {
 /// Helper function that computes a new vector type based on the input vector
 /// type by removing the trailing one dims:
 ///
-///   vector<4x1x1xi1> --> vector<4x1>
+///   vector<4x1x1xi1> --> vector<4x1xi1>
 ///
 static VectorType trimTrailingOneDims(VectorType oldType) {
   ArrayRef<int64_t> oldShape = oldType.getShape();
@@ -6161,12 +6203,40 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(shape_cast) into a new shape_cast.
+class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto shapeCastOp =
+        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
+    if (!shapeCastOp)
+      return failure();
+    if (!isOrderPreserving(transposeOp))
+      return failure();
+    if (transposeOp.getType().isScalable())
+      return failure();
+
+    VectorType resultType = transposeOp.getType();
+
+    // We don't need to check isValidShapeCast at this point, because it is
+    // guaranteed that merging the transpose into the the shape_cast is a valid
+    // shape_cast, because the transpose just inserts/removes ones.
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
+                                                     shapeCastOp.getSource());
+    return success();
+  }
+};
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
-              TransposeFolder, FoldTransposeSplat>(context);
+              FoldTransposeShapeCast, TransposeFolder, FoldTransposeSplat>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..10144cb9034e4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3295,3 +3295,67 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
   %res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
   return %res : vector<4x1xi32>
 }
+
+// -----
+
+// In this test, the permutation maps the non-one dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// CHECK-LABEL: @transpose_shape_cast
+//  CHECK-SAME:   %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<1x4x4x1x1xi8> to vector<4x4xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @transpose_shape_cast(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8> {
+  %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+     : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
+  %1 = vector.shape_cast %0 : vector<4x1x1x1x4xi8> to vector<4x4xi8>
+  return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// In this test, the mapping of non-one indices (1 and 2) is as follows:
+// 1 -> 2
+// 2 -> 1
+// As this is not increasing (2 > 1), this transpose is not order
+// preserving and cannot be treated as a shape_cast.
+// CHECK-LABEL: @negative_transpose_shape_cast
+//  CHECK-SAME:   %[[ARG:.*]]: vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+//       CHECK:   %[[TRANSPOSE:.*]] = vector.transpose %[[ARG]]
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSPOSE]]
+//       CHECK:   return %[[SHAPE_CAST]] : vector<4x4xi8>
+func.func @negative_transpose_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<4x4xi8> {
+  %0 = vector.transpose %arg, [0, 2, 1, 3]
+     : vector<1x4x4x1xi8> to vector<1x4x4x1xi8>
+  %1 = vector.shape_cast %0 : vector<1x4x4x1xi8> to vector<4x4xi8>
+  return %1 : vector<4x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @shape_cast_transpose
+//  CHECK-SAME:   %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<2x3x1x1xi8> to vector<6x1x1xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<6x1x1xi8>
+func.func @shape_cast_transpose(%arg : vector<2x3x1x1xi8>) ->  vector<6x1x1xi8> {
+  %0 = vector.shape_cast %arg : vector<2x3x1x1xi8> to vector<6x1x1xi8>
+  %1 = vector.transpose %0, [0, 2, 1]
+     : vector<6x1x1xi8> to vector<6x1x1xi8>
+  return %1 : vector<6x1x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_shape_cast_transpose
+//  CHECK-SAME:   %[[ARG:.*]]: vector<6xi8>) -> vector<2x3xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//       CHECK:   %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]]
+//       CHECK:   return %[[TRANSPOSE]] : vector<2x3xi8>
+func.func @negative_shape_cast_transpose(%arg : vector<6xi8>) -> vector<2x3xi8> {
+  %0 = vector.shape_cast %arg : vector<6xi8> to vector<3x2xi8>
+  %1 = vector.transpose %0, [1, 0] : vector<3x2xi8> to vector<2x3xi8>
+  return %1 : vector<2x3xi8>
+}

@newling newling force-pushed the fold_transpose_shape_cast branch from b21a4a6 to f4ae206 Compare April 21, 2025 17:29
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, makes sense!

Quite possibly there’s no need to be so strict with "scalable" vectors (thanks for checking!) — let me experiment and confirm, probably tomorrow.

Btw:

  • "not of size 1" vs "non-one dimensions" vs "non-one indices"

We tend to say "non-unit dims" 😉 I hope I'm not being overly pedantic by suggesting using that wording in this PR as well. (Also, ESL here — so feel free to suggest something else if you prefer!)

@newling
Copy link
Contributor Author

newling commented Apr 28, 2025

Quite possibly there’s no need to be so strict with "scalable" vectors (thanks for checking!) — let me experiment and confirm, probably tomorrow.

Thanks!

We tend to say "non-unit dims" 😉 I hope I'm not being overly pedantic by suggesting using that wording in this PR as well. (Also, ESL here — so feel free to suggest something else if you prefer!)

This is useful info, I don't want to create unnecessary fragmentation. Will update this PR.

@banach-space
Copy link
Contributor

Quite possibly there’s no need to be so strict with "scalable" vectors (thanks for checking!) — let me experiment and confirm, probably tomorrow.

Thanks!

I was able to run all end-to-end tests targeting scalable vectors using your patch (after commenting out both instances of if (transpose.getType().isScalable())). Since they all passed, I think it’s safe to remove that guard.

Could you add some dedicated tests for scalable vectors? Also, you’ll want to disable these transformations when encountering scalable unit dims, i.e., [1], since 1 != [1]. The presence of scalable unit dims usually points to bigger modeling issues, so I wouldn’t worry about fully solving them here - just make sure they’re tested accordingly.

Lastly, would you mind adding some labels (or block comments) to "vector-transpose.mlir" to separate the tests for FoldTransposeShapeCastOp from those for ShapeCast::fold?

Thanks again!

@newling
Copy link
Contributor Author

newling commented Apr 30, 2025

Thanks @banach-space I'll address these comments, but just to mention that there still is test that fails (hangs) for me, it is test/Dialect/ArmSME/vector-legalization.mlir

The problem is that the pattern ConvertIllegalShapeCastOpsToTransposes converts

%0 = shape_cast [4]x1 -> [4]

into

%0 = transpose [4]x1 -> 1x[4]
%1 = shape_cast 1x[4] -> [4]

but the canonicalizer added here does the exact reverse! So we get an infinite loop.

The place where this transformation happens:

I've tried to dig a bit deeper, but don't see where shape_cast gets further lowered (I was wondering if shape_cast [4]x1 -> [4] could be lowered as it is without the need for ). Also wondered if maybe shape_cast [4]x1 -> [4] can't always be fused with its defining op?

@banach-space
Copy link
Contributor

Thanks @banach-space I'll address these comments, but just to mention that there still is test that fails (hangs) for me, it is test/Dialect/ArmSME/vector-legalization.mlir

In that case, I think the right approach is to first remove ConvertIllegalShapeCastOpsToTransposes (I suspect it's no longer needed), and then we can relax this logic.

Sorry for the back-and-forth, but let’s keep things as-is for now. Would you mind leaving a TODO for me to revisit and relax these constraints? Ideally, reference ConvertIllegalShapeCastOpsToTransposes in the comment.

Also, would you mind replacing failure() with notifyMatchFailure()? Thanks!

@newling
Copy link
Contributor Author

newling commented May 1, 2025

Thanks @banach-space I'll address these comments, but just to mention that there still is test that fails (hangs) for me, it is test/Dialect/ArmSME/vector-legalization.mlir

In that case, I think the right approach is to first remove ConvertIllegalShapeCastOpsToTransposes (I suspect it's no longer needed), and then we can relax this logic.

Sorry for the back-and-forth, but let’s keep things as-is for now. Would you mind leaving a TODO for me to revisit and relax these constraints? Ideally, reference ConvertIllegalShapeCastOpsToTransposes in the comment.

Also, would you mind replacing failure() with notifyMatchFailure()? Thanks!

I think this approach makes sense. I've made these updates (and am about to push them). I realized that only one of the 2 patterns needs to have scalable vectors disabled, so I've enabled the other one (i.e. transpose(shape_cast) -> shape_cast can handle scalable vectors fine, I didn't need to fail in this case too).
Thanks!

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made one final request - otherwise, this LGTM. Thanks!
(Approving as-is, but please address that final comment.)

By the way, I’ll be away next week. Absolutely no need to wait for me to review or approve anything - just letting you know so you don’t feel blocked by a lack of reviews. And I have a feeling you'll be sending more patches soon anyway :)

ArrayRef<int64_t> inShape = transpose.getSourceVectorType().getShape();
int64_t current = 0;
for (auto p : permutation) {
if (inShape[p] != 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic should also check the scalable flags - otherwise [1] will be treated as 1. Basically, we should disable it for things like this:

   %1 = vector.transpose %0, [0, 2, 1]
     : vector<6x1x[1]xi8> to vector<6x[1]x1xi8>

We are very unlikely to ever generate this, but it's better to be safe than sorry.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, sorry you mentioned that earlier too. Will fix this.

@newling
Copy link
Contributor Author

newling commented May 1, 2025

By the way, I’ll be away next week. Absolutely no need to wait for me to review or approve anything - just letting you know so you don’t feel blocked by a lack of reviews. And I have a feeling you'll be sending more patches soon anyway :)

Ok thanks for letting me know. And thank you so much for all the reviewing you've been doing on my PRs :)

@newling newling force-pushed the fold_transpose_shape_cast branch from ec66724 to 11fd6c3 Compare May 1, 2025 20:25
@newling newling merged commit 1a44f38 into llvm:main May 1, 2025
11 checks passed
@@ -5575,6 +5575,34 @@ LogicalResult ShapeCastOp::verify() {
return success();
}

namespace {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I intend to change this in a follow-up PR.

// ConvertIllegalShapeCastOpsToTransposes only matches on scalable
// vectors, so by disabling this folder for scalable vectors the
// cycle is avoided.
// TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know why we generate illegal shape cast ops in first place? It sounds like something that shouldn't happen...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're illegal according to an arm specific lowering target, which I'm not familiar with

struct ConvertIllegalShapeCastOpsToTransposes

I think @banach-space suspects that actually they're not illegal and will investigate the removal of this constraint here (and the pattern ConvertIllegalShapeCastOpsToTransposes).

IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…135841)

Handles special case where transpose doesn't permute any non-1
dimensions (and so is effectively a shape_cast) and is adjacent to a
shape_cast that it can fold into. For example

```
%1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x1x1x6xf32> to vector<1x4x6x1xf32>
```

can be folded into an adjacent shape_cast. An alternative to this PR
would be to canonicalize such transposes to shape_casts directly, but I
think it'll be difficult getting consensus that shape_cast is 'more
canonical' than transpose, so this PR compromises with the less
opinionated claim that

1) shape_cast is more canonical than shape_cast(transpose)
2) shape_cast is more canonical than transpose(shape_cast)

The pattern `ConvertIllegalShapeCastOpsToTransposes` that is specific to
transposes with scalable dimensions reverses the canonicalization added
here, so I've I've disabled this canonicalization for scalable vectors
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…135841)

Handles special case where transpose doesn't permute any non-1
dimensions (and so is effectively a shape_cast) and is adjacent to a
shape_cast that it can fold into. For example

```
%1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x1x1x6xf32> to vector<1x4x6x1xf32>
```

can be folded into an adjacent shape_cast. An alternative to this PR
would be to canonicalize such transposes to shape_casts directly, but I
think it'll be difficult getting consensus that shape_cast is 'more
canonical' than transpose, so this PR compromises with the less
opinionated claim that

1) shape_cast is more canonical than shape_cast(transpose)
2) shape_cast is more canonical than transpose(shape_cast)

The pattern `ConvertIllegalShapeCastOpsToTransposes` that is specific to
transposes with scalable dimensions reverses the canonicalization added
here, so I've I've disabled this canonicalization for scalable vectors
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…135841)

Handles special case where transpose doesn't permute any non-1
dimensions (and so is effectively a shape_cast) and is adjacent to a
shape_cast that it can fold into. For example

```
%1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x1x1x6xf32> to vector<1x4x6x1xf32>
```

can be folded into an adjacent shape_cast. An alternative to this PR
would be to canonicalize such transposes to shape_casts directly, but I
think it'll be difficult getting consensus that shape_cast is 'more
canonical' than transpose, so this PR compromises with the less
opinionated claim that

1) shape_cast is more canonical than shape_cast(transpose)
2) shape_cast is more canonical than transpose(shape_cast)

The pattern `ConvertIllegalShapeCastOpsToTransposes` that is specific to
transposes with scalable dimensions reverses the canonicalization added
here, so I've I've disabled this canonicalization for scalable vectors
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
…135841)

Handles special case where transpose doesn't permute any non-1
dimensions (and so is effectively a shape_cast) and is adjacent to a
shape_cast that it can fold into. For example

```
%1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x1x1x6xf32> to vector<1x4x6x1xf32>
```

can be folded into an adjacent shape_cast. An alternative to this PR
would be to canonicalize such transposes to shape_casts directly, but I
think it'll be difficult getting consensus that shape_cast is 'more
canonical' than transpose, so this PR compromises with the less
opinionated claim that

1) shape_cast is more canonical than shape_cast(transpose)
2) shape_cast is more canonical than transpose(shape_cast)

The pattern `ConvertIllegalShapeCastOpsToTransposes` that is specific to
transposes with scalable dimensions reverses the canonicalization added
here, so I've I've disabled this canonicalization for scalable vectors
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants