Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[mlir][vector] Linearization: push 'bit width' logic out of patterns #136581

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 30, 2025

Conversation

newling
Copy link
Contributor

@newling newling commented Apr 21, 2025

[NFC]

Vector linearization is a collection of rewrite patterns that reduce the rank of vector operands and results.

In #83314 an option to ignore (make 'legal') operations with large inner-most dimensions was added. This current PR is a step towards making that option live outside of upstream MLIR. The motivation is to remove non-core functionality (I would like to use this pass, but would prefer not to deal with 'targetVectorBitWidth` at all).

As a follow-up to this PR, I propose that user(s) of the targetVectorBitWidth move the relevant code (now in mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code bases, and then remove it from upstream. In addition the tests need to split out (I've intentionally not modified the lit tests here, to make it easier to confirm that this is a NFC). I'm very happy to help make it easier to do this final step!

The approach I've used is to move the logic pertaining to targetVectorBitWidth out the patterns, and into the conversion target, which the end user can control outside of core MLIR.

Copy link

github-actions bot commented Apr 21, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@llvmbot
Copy link
Member

llvmbot commented Apr 21, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

Vector linearization is a collection of rewrite patterns that reduce the rank of vector operands and results of operations.

In #83314 an option to ignore (legalize) operations with large inner-most dimensions was added. This current PR is a step towards making that option live outside of upstream MLIR. The motivation is to reduce non-core functionality (I would like to use this pass, but would prefer not to deal with 'targetVectorBitWidth` at all).

As a follow-up to this PR, I propose that user(s) of the targetVectorBitWidth move legalBecauseOfBitwidth to their code bases, and then remove it from upstream.

The approach I've used is to move the logic pertaining to targetVectorBitWidth out the patterns, and into the conversion target, which the end user can control outside of core MLIR.


Patch is 26.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136581.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+18-12)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+130-134)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+13)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+9-8)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index ce97847172197..d9a0791cdea33 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -392,18 +392,24 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Populates patterns for ND vectors (N >= 2) linearization and sets up the
-/// provided ConversionTarget with the appropriate legality configuration for
-/// the ops to get converted properly.
-void populateVectorLinearizeTypeConversionsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
-
-/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
-/// vector shuffle operations.
-void populateVectorLinearizeShuffleLikeOpsPatterns(
-    const TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth);
+/// Populate `typeConverter` and `conversionTarget` with the definition of
+/// legal types and operations, for the specific case where vectors with
+/// trailing dimensions of size greater than `targetBitWidth` are legal.
+void populateVectorLinearizeBitWidthTargetAndConverter(
+    TypeConverter &typeConverter, ConversionTarget &conversionTarget,
+    unsigned targetBitWidth);
+
+/// Populates `patterns` for ND vector (N >= 2) linearization. Patterns for
+/// converting ConstantLike, Vectorizable, and vector::BitCast.
+void populateVectorLinearizeBasePatterns(const TypeConverter &,
+                                         RewritePatternSet &patterns,
+                                         const ConversionTarget &);
+
+/// Populates `patterns` for linearizing ND (N >= 2) vector operations
+/// to 1D vector shuffle operations.
+void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
+                                                   RewritePatternSet &patterns,
+                                                   const ConversionTarget &);
 
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..e24c8ee961c51 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -10,7 +10,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -22,44 +21,16 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include <cstdint>
+#include <limits>
 #include <numeric>
+#include <optional>
 
 using namespace mlir;
 
-static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
-  auto resultTypes = op->getResultTypes();
-  for (auto resType : resultTypes) {
-    VectorType vecType = dyn_cast<VectorType>(resType);
-    // Reject index since getElementTypeBitWidth will abort for Index types.
-    if (!vecType || vecType.getElementType().isIndex())
-      return false;
-    // There are no dimension to fold if it is a 0-D vector.
-    if (vecType.getRank() == 0)
-      return false;
-    unsigned trailingVecDimBitWidth =
-        vecType.getShape().back() * vecType.getElementTypeBitWidth();
-    if (trailingVecDimBitWidth >= targetBitWidth)
-      return false;
-  }
-  return true;
-}
-
-static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
-  VectorType vecType = dyn_cast<VectorType>(t);
-  // Reject index since getElementTypeBitWidth will abort for Index types.
-  if (!vecType || vecType.getElementType().isIndex())
-    return false;
-  // There are no dimension to fold if it is a 0-D vector.
-  if (vecType.getRank() == 0)
-    return false;
-  unsigned trailingVecDimBitWidth =
-      vecType.getShape().back() * vecType.getElementTypeBitWidth();
-  return trailingVecDimBitWidth <= targetBitWidth;
-}
-
 static FailureOr<Attribute>
 linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
                    VectorType resType, Attribute value) {
+
   if (auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
     if (resType.isScalable() && !isa<SplatElementsAttr>(value))
       return rewriter.notifyMatchFailure(
@@ -76,16 +47,14 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
 }
 
 namespace {
+
 struct LinearizeConstantLike final
     : OpTraitConversionPattern<OpTrait::ConstantLike> {
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
-  LinearizeConstantLike(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpTraitConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+  LinearizeConstantLike(const TypeConverter &typeConverter,
+                        MLIRContext *context, PatternBenefit benefit = 1)
+      : OpTraitConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
@@ -100,10 +69,6 @@ struct LinearizeConstantLike final
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type");
 
-    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          loc, "Can't flatten since targetBitWidth <= OpSize");
-
     StringAttr attrName = rewriter.getStringAttr("value");
     Attribute value = op->getAttr(attrName);
     if (!value)
@@ -124,9 +89,6 @@ struct LinearizeConstantLike final
     rewriter.replaceOp(op, newOp);
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 struct LinearizeVectorizable final
@@ -134,18 +96,12 @@ struct LinearizeVectorizable final
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
 public:
-  LinearizeVectorizable(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpTraitConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+  LinearizeVectorizable(const TypeConverter &typeConverter,
+                        MLIRContext *context, PatternBenefit benefit = 1)
+      : OpTraitConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
     FailureOr<Operation *> newOp =
         convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
     if (failed(newOp))
@@ -154,9 +110,6 @@ struct LinearizeVectorizable final
     rewriter.replaceOp(op, (*newOp)->getResults());
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
@@ -173,12 +126,10 @@ struct LinearizeVectorizable final
 struct LinearizeVectorExtractStridedSlice final
     : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
-  LinearizeVectorExtractStridedSlice(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+  LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
+                                     MLIRContext *context,
+                                     PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
@@ -189,9 +140,6 @@ struct LinearizeVectorExtractStridedSlice final
     if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
       return rewriter.notifyMatchFailure(extractOp,
                                          "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
     ArrayAttr offsets = extractOp.getOffsets();
     ArrayAttr sizes = extractOp.getSizes();
@@ -268,9 +216,6 @@ struct LinearizeVectorExtractStridedSlice final
         extractOp, dstType, srcVector, srcVector, indices);
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ShuffleOp that works on nD (n > 1)
@@ -291,8 +236,7 @@ struct LinearizeVectorShuffle final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
@@ -302,13 +246,12 @@ struct LinearizeVectorShuffle final
     assert(dstType && "vector type destination expected.");
     // The assert is used because vector.shuffle does not support scalable
     // vectors.
-    assert(!(shuffleOp.getV1VectorType().isScalable() ||
-             shuffleOp.getV2VectorType().isScalable() ||
-             dstType.isScalable()) &&
-           "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
+    bool scalable = shuffleOp.getV1VectorType().isScalable() ||
+                    shuffleOp.getV2VectorType().isScalable() ||
+                    dstType.isScalable();
+    if (scalable)
+      return rewriter.notifyMatchFailure(shuffleOp,
+                                         "scalable vectors are not supported.");
 
     Value vec1 = adaptor.getV1();
     Value vec2 = adaptor.getV2();
@@ -343,9 +286,6 @@ struct LinearizeVectorShuffle final
                                                    vec2, indices);
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ExtractOp to a ShuffleOp that works on a
@@ -364,8 +304,7 @@ struct LinearizeVectorExtract final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -378,9 +317,6 @@ struct LinearizeVectorExtract final
         cast<VectorType>(dstTy).isScalable())
       return rewriter.notifyMatchFailure(extractOp,
                                          "scalable vectors are not supported.");
-    if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          extractOp, "Can't flatten since targetBitWidth <= OpSize");
 
     // Dynamic position is not supported.
     if (extractOp.hasDynamicPosition())
@@ -405,9 +341,6 @@ struct LinearizeVectorExtract final
 
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the InsertOp to a ShuffleOp that works on a
@@ -427,8 +360,7 @@ struct LinearizeVectorInsert final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -439,11 +371,6 @@ struct LinearizeVectorInsert final
       return rewriter.notifyMatchFailure(insertOp,
                                          "scalable vectors are not supported.");
 
-    if (!isLessThanOrEqualTargetBitWidth(insertOp.getValueToStoreType(),
-                                         targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          insertOp, "Can't flatten since targetBitWidth < OpSize");
-
     // dynamic position is not supported
     if (insertOp.hasDynamicPosition())
       return rewriter.notifyMatchFailure(insertOp,
@@ -471,11 +398,11 @@ struct LinearizeVectorInsert final
     }
 
     llvm::SmallVector<int64_t, 2> indices(dstSize);
-    auto origValsUntil = indices.begin();
+    auto *origValsUntil = indices.begin();
     std::advance(origValsUntil, linearizedOffset);
     std::iota(indices.begin(), origValsUntil,
               0); // original values that remain [0, offset)
-    auto newValsUntil = origValsUntil;
+    auto *newValsUntil = origValsUntil;
     std::advance(newValsUntil, srcSize);
     std::iota(origValsUntil, newValsUntil,
               dstSize); // new values [offset, offset+srcNumElements)
@@ -488,9 +415,6 @@ struct LinearizeVectorInsert final
 
     return success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the BitCastOp that works on nD (n > 1)
@@ -508,8 +432,7 @@ struct LinearizeVectorBitCast final
       const TypeConverter &typeConverter, MLIRContext *context,
       unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
       PatternBenefit benefit = 1)
-      : OpConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
+      : OpConversionPattern(typeConverter, context, benefit) {}
   LogicalResult
   matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -518,24 +441,103 @@ struct LinearizeVectorBitCast final
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type.");
 
-    if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          loc, "Can't flatten since targetBitWidth <= OpSize");
-
     rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                    adaptor.getSource());
     return mlir::success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 } // namespace
 
-void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth) {
+/// If `type` is VectorType with trailing dimension of (bit) size greater than
+/// or equal to `targetBitWidth`, its defining op is considered legal.
+static bool legalBecauseOfBitwidth(Type type, unsigned targetBitWidth) {
+
+  VectorType vecType = dyn_cast<VectorType>(type);
+
+  if (!vecType)
+    return true;
+
+  // The width of the type 'index' is unbounded (and therefore potentially above
+  // the target width).
+  if (vecType.getElementType().isIndex())
+    return true;
+
+  unsigned finalDimSize =
+      vecType.getRank() == 0 ? 0 : vecType.getShape().back();
+
+  unsigned trailingVecDimBitWidth =
+      finalDimSize * vecType.getElementTypeBitWidth();
+
+  return trailingVecDimBitWidth >= targetBitWidth;
+}
+
+static SmallVector<std::pair<Type, unsigned>>
+getChecksForBitwidth(Operation *op, unsigned targetBitWidth) {
+
+  if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+    auto w = targetBitWidth < std::numeric_limits<unsigned>::max()
+                 ? targetBitWidth + 1
+                 : targetBitWidth;
+    return {{insertOp.getValueToStoreType(), w}};
+  }
+  auto resultTypes = op->getResultTypes();
+  SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+  resultsWithBitWidth.reserve(resultTypes.size());
+  for (Type type : resultTypes) {
+    resultsWithBitWidth.push_back({type, targetBitWidth});
+  }
+  return resultsWithBitWidth;
+}
+
+/// Return true if the operation `op` does not support scalable vectors and
+/// has at least 1 scalable vector result.
+static bool legalBecauseScalable(Operation *op) {
+
+  bool scalableSupported = op->hasTrait<OpTrait::ConstantLike>() ||
+                           op->hasTrait<OpTrait::Vectorizable>() ||
+                           isa<vector::BitCastOp>(op);
+
+  if (scalableSupported)
+    return false;
+
+  // Check if any of the results is a scalable vector type.
+  auto types = op->getResultTypes();
+  bool containsScalableResult =
+      std::any_of(types.begin(), types.end(), [](Type type) {
+        auto vecType = dyn_cast<VectorType>(type);
+        return vecType && vecType.isScalable();
+      });
+
+  return containsScalableResult;
+}
+
+static bool dynamicallyLegal(Operation *op, unsigned targetBitWidth) {
+
+  // Only ops that are in the vector dialect, are ConstantLike, or
+  // are Vectorizable might be linearized currently, so legalize the others.
+  bool opIsVectorDialect = op->getDialect()->getNamespace() ==
+                           vector::VectorDialect::getDialectNamespace();
+  if (!opIsVectorDialect && !op->hasTrait<OpTrait::ConstantLike>() &&
+      !op->hasTrait<OpTrait::Vectorizable>())
+    return true;
+
+  // Some ops will not be linearized if they have scalable vector results.
+  if (legalBecauseScalable(op))
+    return true;
+
+  // Check on bitwidths.
+  auto typesToCheck = getChecksForBitwidth(op, targetBitWidth);
+  return std::any_of(typesToCheck.begin(), typesToCheck.end(),
+                     [&](std::pair<Type, unsigned> typeWidth) {
+                       return legalBecauseOfBitwidth(typeWidth.first,
+                                                     typeWidth.second);
+                     });
+}
+
+void mlir::vector::populateVectorLinearizeBitWidthTargetAndConverter(
+    TypeConverter &typeConverter, ConversionTarget &target,
+    unsigned targetBitWidth) {
 
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     if (!isLinearizableVector(type))
@@ -550,40 +552,34 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
         !isa<VectorType>(type))
       return nullptr;
-
     return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
   };
+
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
+
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<vector::BitCastOp>(op) ||
-             op->hasTrait<OpTrait::ConstantLike>() ||
-             op->hasTrait<OpTrait::Vectorizable>())) {
-          return (isLessThanTargetBitWidth(op, targetBitWidth)
-                      ? typeConverter.isLegal(op)
-                      : true);
-        }
-        return std::nullopt;
+        bool isDynamicallyLegal = dynamicallyLegal(op, targetBitWidth);
+        if (isDynamicallyLegal)
+          return true;
+
+        bool shapeUnchanged = typeConverter.isLegal(op);
+        return shapeUnchanged;
       });
+}
 
+void mlir::vector::populateVectorLinearizeBasePatterns(
+    const TypeC...
[truncated]

@newling newling changed the title [vector][linearize] Refactor code to push target bit width out of core code [vector][linearize] Refactor code to push target bit width out of patterns Apr 21, 2025
Copy link
Contributor

@Hardcode84 Hardcode84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but give it day or two for other vector folks to have a chance to chime in

@banach-space
Copy link
Contributor

Overall makes sense. In fact, this is a nice clean-up, thanks!

As a follow-up to this PR, I propose that user(s) of the targetVectorBitWidth move legalBecauseOfBitwidth to their code bases, and then remove it from upstream.

How are we going to "drive" the existing tests that depend on this? (through e.g. test-vector-linearize=target-vector-bitwidth=128).

@@ -550,40 +552,34 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
!isa<VectorType>(type))
return nullptr;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] I find empty lines quite helpful 😅

Comment on lines 497 to 499
bool scalableSupported = op->hasTrait<OpTrait::ConstantLike>() ||
op->hasTrait<OpTrait::Vectorizable>() ||
isa<vector::BitCastOp>(op);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am quite confused by this logic ... I don't see anything specific to scalable vectors here?

Also, perhaps flip the logic in this method:

/// Return true if the operation op does not support scalable vectors

This would make more sense to me:

/// Return true if the operation op supports scalable vectors

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was just trying to do a pure NFC refactor, and these were the only ops that support scalable vectors. Flipping the logic is a good idea, will make this clearer.

@newling
Copy link
Contributor Author

newling commented Apr 22, 2025

Overall makes sense. In fact, this is a nice clean-up, thanks!

As a follow-up to this PR, I propose that user(s) of the targetVectorBitWidth move legalBecauseOfBitwidth to their code bases, and then remove it from upstream.

How are we going to "drive" the existing tests that depend on this? (through e.g. test-vector-linearize=target-vector-bitwidth=128).

Great point. I will try refactor the testing to make it an easier lift, and wait for the folks who use this to give their thoughts

@newling newling changed the title [vector][linearize] Refactor code to push target bit width out of patterns [mlir][vector] Linearization: push 'bit width' logic out of patterns Apr 22, 2025
@newling
Copy link
Contributor Author

newling commented Apr 22, 2025

Overall makes sense. In fact, this is a nice clean-up, thanks!

As a follow-up to this PR, I propose that user(s) of the targetVectorBitWidth move legalBecauseOfBitwidth to their code bases, and then remove it from upstream.

How are we going to "drive" the existing tests that depend on this? (through e.g. test-vector-linearize=target-vector-bitwidth=128).

Great point. I will try refactor the testing to make it an easier lift, and wait for the folks who use this to give their thoughts

I've factorized out a bit deeper now, it's clearer to me now what the lift out of llvm-project would look like

Some feedback from the users of the bit width logic (@bviyer @dcaballe @nbpatel ?) would be great!

@newling newling merged commit bad8bf5 into llvm:main Apr 30, 2025
11 checks passed
@banach-space
Copy link
Contributor

Sorry for the delay with this, I was waiting for either other reviewers to comment or for a "ping", I wasn't expecting this to be merged in the meantime 😅 Feel free to ping me directly if things go stale.

I've left a couple of comments - these can be addressed in a separate PR (this change is overall good, thanks!). I'm just still a bit confused about the logic around "scalable" vectors. Since you've looked at it recently, do you know the answer? (see my question inline)

: public PassWrapper<TestVectorLinearize, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
// TODO: move this code into the user project.
namespace vendor {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is vendor and why "user project"?

return containsScalableResult;
}

static bool isNotLinearizable(Operation *op) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this method be documented? What does it mean to be "non-linearizable"?

Also, I'm confused about isNotLinearizableBecauseScalable. Is linearization disabled in the presence of scalable vectors?

@newling
Copy link
Contributor Author

newling commented Apr 30, 2025

Sorry for the delay with this, I was waiting for either other reviewers to comment or for a "ping", I wasn't expecting this to be merged in the meantime 😅 Feel free to ping me directly if things go stale.

I've left a couple of comments - these can be addressed in a separate PR (this change is overall good, thanks!). I'm just still a bit confused about the logic around "scalable" vectors. Since you've looked at it recently, do you know the answer? (see my question inline)

Apologies @banach-space , I assumed you were fine with committing as is. I'll 'ping' the PR next time. In response to your questions (I'll also post a PR improving the docs):

What does it mean to be "non-linearizable"?

Here it just means "we don't have a pattern to linearize it yet". Or equivalently, that the ConversionTarget should consider it legal, and so we should not fail if it is not converted by a pattern.

Is linearization disabled in the presence of scalable vectors?

Yes, but only for some ops. Other ops have implementations for linearizing scalable vectors.

I think some of the confusion (and complexity) here could be avoided if linearization was implemented as just a bunch of patterns to append to a RewritePatternSet, like most of the functions in VectorRewritePatterns.h. I'm not sure why it was implemented as a legalization (i.e. with a ConversionTarget and a TypeConverter) in #81159 ( @Hardcode84 ? )

Also if anyone is opposed to this PR please let me know and I'll revert it (ping : @dcaballe @bviyer )

@Hardcode84
Copy link
Contributor

The motivation to use dialect conversion instead of just greedy rewriter was so we can reuse existing generic scf/cf/func conversions to convert vector types through control flow ops.

@newling
Copy link
Contributor Author

newling commented Apr 30, 2025

The motivation to use dialect conversion instead of just greedy rewriter was so we can reuse existing generic scf/cf/func conversions to convert vector types through control flow ops.

Can you please give a bit more information, I'm not familiar with this. A pointer to an example of where this is done would be useful. Is this specific to linearization, or do you think it should be approach for the other pattern-adding APIs in VectorRewritePatterns ?

@Hardcode84
Copy link
Contributor

There are generic conversions like populateSCFStructuralTypeConversionsAndLegality https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h#L52 which will update scf loop-carried vars, yields, etc according to the type converter. So if we want to linearize vector types used in scf.for for example we don't need to write scf.for-specific patterns for linearization.

@newling
Copy link
Contributor Author

newling commented Apr 30, 2025

Got it, thanks for the clear explanation

@nbpatel
Copy link
Contributor

nbpatel commented May 1, 2025

hey I still see bitwidth in some patterns, https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp#L295C7-L295C44 ...did you revert the change? or was it not merged correctly, or why is it still there for some patterns and not for others? sorry I thought its being removed completely from here

@newling
Copy link
Contributor Author

newling commented May 1, 2025

hey I still see bitwidth in some patterns, https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp#L295C7-L295C44 ...did you revert the change? or was it not merged correctly, or why is it still there for some patterns and not for others? sorry I thought its being removed completely from here

Oh no sorry, I just messed up. I thought I had removed them all, let me try and remove those ones too. Not sure how I missed that.

@newling
Copy link
Contributor Author

newling commented May 1, 2025

@nbpatel this removes the remaining bitwidths #138072 . Thanks for flagging this to me!

@banach-space
Copy link
Contributor

Apologies @banach-space , I assumed you were fine with committing as is. I'll 'ping' the PR next time.

No worries at all!

There was already one approval, and technically that’s all that’s required —-per “How a Patch Is Accepted”:

Only approval from a single reviewer is required

Personally, I tend to ping folks who’ve left comments (but haven’t explicitly approved) before merging, just to be safe - but that’s more my own habit than an LLVM policy 😅

newling added a commit that referenced this pull request May 1, 2025
…#138072)

In #136581 the bitwidth logic
was supposed to be completely removed from the linearization patterns.
But it was left in a few places. This PR removes the remainders (they
were default valued constructor arguments that were unused).
newling added a commit that referenced this pull request May 1, 2025
In #136581 the logic pertaining
to bitwidth was removed from the patterns. This PR further factorizes
bitwidth logic out of the main test file.

The number of tests with bitwidth (in the new file added in this PR) is
now lower than before this PR. This is because this PR only tests the
bitwidth specific logic once (there was a fair amount of redundant
testing before).

I didn't do this test refactoring in
#136581 because I wanted to
make it clear that it was NFC by leaving the tests unchanged there
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…lvm#136581)

[NFC]
Vector linearization is a collection of rewrite patterns that reduce the
rank of vector operands and results.

In llvm#83314 an option to ignore
(make 'legal') operations with large inner-most dimensions was added.
This current PR is a step towards making that option live outside of
upstream MLIR. The motivation is to remove non-core functionality (I
would like to use this pass, but would prefer not to deal with
'targetVectorBitWidth` at all).

As a follow-up to this PR, I propose that user(s) of the
`targetVectorBitWidth` move the relevant code (now in
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code
bases, and then eventually remove it from upstream. In addition the tests need to
split out (I've intentionally not modified the lit tests here, to make
it easier to confirm that this is a NFC). I'm happy to help make it
easier to do this final step!

The approach I've used is to move the logic pertaining to
`targetVectorBitWidth` out the patterns, and into the conversion target,
which the end user can control outside of core MLIR.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…llvm#138072)

In llvm#136581 the bitwidth logic
was supposed to be completely removed from the linearization patterns.
But it was left in a few places. This PR removes the remainders (they
were default valued constructor arguments that were unused).
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
In llvm#136581 the logic pertaining
to bitwidth was removed from the patterns. This PR further factorizes
bitwidth logic out of the main test file.

The number of tests with bitwidth (in the new file added in this PR) is
now lower than before this PR. This is because this PR only tests the
bitwidth specific logic once (there was a fair amount of redundant
testing before).

I didn't do this test refactoring in
llvm#136581 because I wanted to
make it clear that it was NFC by leaving the tests unchanged there
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…lvm#136581)

[NFC]
Vector linearization is a collection of rewrite patterns that reduce the
rank of vector operands and results.

In llvm#83314 an option to ignore
(make 'legal') operations with large inner-most dimensions was added.
This current PR is a step towards making that option live outside of
upstream MLIR. The motivation is to remove non-core functionality (I
would like to use this pass, but would prefer not to deal with
'targetVectorBitWidth` at all).

As a follow-up to this PR, I propose that user(s) of the
`targetVectorBitWidth` move the relevant code (now in
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code
bases, and then eventually remove it from upstream. In addition the tests need to
split out (I've intentionally not modified the lit tests here, to make
it easier to confirm that this is a NFC). I'm happy to help make it
easier to do this final step!

The approach I've used is to move the logic pertaining to
`targetVectorBitWidth` out the patterns, and into the conversion target,
which the end user can control outside of core MLIR.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…llvm#138072)

In llvm#136581 the bitwidth logic
was supposed to be completely removed from the linearization patterns.
But it was left in a few places. This PR removes the remainders (they
were default valued constructor arguments that were unused).
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
In llvm#136581 the logic pertaining
to bitwidth was removed from the patterns. This PR further factorizes
bitwidth logic out of the main test file.

The number of tests with bitwidth (in the new file added in this PR) is
now lower than before this PR. This is because this PR only tests the
bitwidth specific logic once (there was a fair amount of redundant
testing before).

I didn't do this test refactoring in
llvm#136581 because I wanted to
make it clear that it was NFC by leaving the tests unchanged there
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…lvm#136581)

[NFC]
Vector linearization is a collection of rewrite patterns that reduce the
rank of vector operands and results.

In llvm#83314 an option to ignore
(make 'legal') operations with large inner-most dimensions was added.
This current PR is a step towards making that option live outside of
upstream MLIR. The motivation is to remove non-core functionality (I
would like to use this pass, but would prefer not to deal with
'targetVectorBitWidth` at all).

As a follow-up to this PR, I propose that user(s) of the
`targetVectorBitWidth` move the relevant code (now in
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code
bases, and then eventually remove it from upstream. In addition the tests need to
split out (I've intentionally not modified the lit tests here, to make
it easier to confirm that this is a NFC). I'm happy to help make it
easier to do this final step!

The approach I've used is to move the logic pertaining to
`targetVectorBitWidth` out the patterns, and into the conversion target,
which the end user can control outside of core MLIR.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…llvm#138072)

In llvm#136581 the bitwidth logic
was supposed to be completely removed from the linearization patterns.
But it was left in a few places. This PR removes the remainders (they
were default valued constructor arguments that were unused).
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
In llvm#136581 the logic pertaining
to bitwidth was removed from the patterns. This PR further factorizes
bitwidth logic out of the main test file.

The number of tests with bitwidth (in the new file added in this PR) is
now lower than before this PR. This is because this PR only tests the
bitwidth specific logic once (there was a fair amount of redundant
testing before).

I didn't do this test refactoring in
llvm#136581 because I wanted to
make it clear that it was NFC by leaving the tests unchanged there
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request May 6, 2025
…on patterns (#138072)

In llvm/llvm-project#136581 the bitwidth logic
was supposed to be completely removed from the linearization patterns.
But it was left in a few places. This PR removes the remainders (they
were default valued constructor arguments that were unused).
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request May 6, 2025
In llvm/llvm-project#136581 the logic pertaining
to bitwidth was removed from the patterns. This PR further factorizes
bitwidth logic out of the main test file.

The number of tests with bitwidth (in the new file added in this PR) is
now lower than before this PR. This is because this PR only tests the
bitwidth specific logic once (there was a fair amount of redundant
testing before).

I didn't do this test refactoring in
llvm/llvm-project#136581 because I wanted to
make it clear that it was NFC by leaving the tests unchanged there
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
…lvm#136581)

[NFC]
Vector linearization is a collection of rewrite patterns that reduce the
rank of vector operands and results.

In llvm#83314 an option to ignore
(make 'legal') operations with large inner-most dimensions was added.
This current PR is a step towards making that option live outside of
upstream MLIR. The motivation is to remove non-core functionality (I
would like to use this pass, but would prefer not to deal with
'targetVectorBitWidth` at all).

As a follow-up to this PR, I propose that user(s) of the
`targetVectorBitWidth` move the relevant code (now in
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code
bases, and then eventually remove it from upstream. In addition the tests need to
split out (I've intentionally not modified the lit tests here, to make
it easier to confirm that this is a NFC). I'm happy to help make it
easier to do this final step!

The approach I've used is to move the logic pertaining to
`targetVectorBitWidth` out the patterns, and into the conversion target,
which the end user can control outside of core MLIR.
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
…llvm#138072)

In llvm#136581 the bitwidth logic
was supposed to be completely removed from the linearization patterns.
But it was left in a few places. This PR removes the remainders (they
were default valued constructor arguments that were unused).
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
In llvm#136581 the logic pertaining
to bitwidth was removed from the patterns. This PR further factorizes
bitwidth logic out of the main test file.

The number of tests with bitwidth (in the new file added in this PR) is
now lower than before this PR. This is because this PR only tests the
bitwidth specific logic once (there was a fair amount of redundant
testing before).

I didn't do this test refactoring in
llvm#136581 because I wanted to
make it clear that it was NFC by leaving the tests unchanged there
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants