Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge branch 'main' into vector_linearize_splat
  • Loading branch information
nbpatel committed May 1, 2025
commit fef6390eaf3abb4567dc955b9c087a69971037e2
167 changes: 73 additions & 94 deletions mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,40 +26,6 @@

using namespace mlir;

constexpr unsigned defaultTargetVectorBitWidth =
std::numeric_limits<unsigned>::max();

static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
auto resultTypes = op->getResultTypes();
for (auto resType : resultTypes) {
VectorType vecType = dyn_cast<VectorType>(resType);
// Reject index since getElementTypeBitWidth will abort for Index types.
if (!vecType || vecType.getElementType().isIndex())
return false;
// There are no dimension to fold if it is a 0-D vector.
if (vecType.getRank() == 0)
return false;
unsigned trailingVecDimBitWidth =
vecType.getShape().back() * vecType.getElementTypeBitWidth();
if (trailingVecDimBitWidth >= targetBitWidth)
return false;
}
return true;
}

static bool isLessThanOrEqualTargetBitWidth(Type t, unsigned targetBitWidth) {
VectorType vecType = dyn_cast<VectorType>(t);
// Reject index since getElementTypeBitWidth will abort for Index types.
if (!vecType || vecType.getElementType().isIndex())
return false;
// There are no dimension to fold if it is a 0-D vector.
if (vecType.getRank() == 0)
return false;
unsigned trailingVecDimBitWidth =
vecType.getShape().back() * vecType.getElementTypeBitWidth();
return trailingVecDimBitWidth <= targetBitWidth;
}

static FailureOr<Attribute>
linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
VectorType resType, Attribute value) {
Expand All @@ -85,12 +51,9 @@ struct LinearizeConstantLike final
: OpTraitConversionPattern<OpTrait::ConstantLike> {
using OpTraitConversionPattern::OpTraitConversionPattern;

LinearizeConstantLike(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpTraitConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
LinearizeConstantLike(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -130,12 +93,9 @@ struct LinearizeVectorizable final
using OpTraitConversionPattern::OpTraitConversionPattern;

public:
LinearizeVectorizable(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpTraitConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
LinearizeVectorizable(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpTraitConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -163,12 +123,10 @@ struct LinearizeVectorizable final
struct LinearizeVectorExtractStridedSlice final
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtractStridedSlice(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
MLIRContext *context,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}

LogicalResult
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
Expand Down Expand Up @@ -271,12 +229,9 @@ struct LinearizeVectorExtractStridedSlice final
struct LinearizeVectorShuffle final
: public OpConversionPattern<vector::ShuffleOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeVectorShuffle(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
LinearizeVectorShuffle(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}

LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
Expand Down Expand Up @@ -332,12 +287,9 @@ struct LinearizeVectorShuffle final
struct LinearizeVectorExtract final
: public OpConversionPattern<vector::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeVectorExtract(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
LinearizeVectorExtract(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -386,12 +338,9 @@ struct LinearizeVectorExtract final
struct LinearizeVectorInsert final
: public OpConversionPattern<vector::InsertOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeVectorInsert(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
LinearizeVectorInsert(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -456,12 +405,9 @@ struct LinearizeVectorInsert final
struct LinearizeVectorBitCast final
: public OpConversionPattern<vector::BitCastOp> {
using OpConversionPattern::OpConversionPattern;
LinearizeVectorBitCast(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
LinearizeVectorBitCast(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Expand All @@ -486,12 +432,9 @@ struct LinearizeVectorSplat final
: public OpConversionPattern<vector::SplatOp> {
using OpConversionPattern::OpConversionPattern;

LinearizeVectorSplat(
const TypeConverter &typeConverter, MLIRContext *context,
unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit),
targetVectorBitWidth(targetVectBitWidth) {}
LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}

LogicalResult
matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
Expand All @@ -515,9 +458,48 @@ struct LinearizeVectorSplat final
/// support scalable vectors, and this function should be removed.
static bool isNotLinearizableBecauseScalable(Operation *op) {

typeConverter.addConversion([](Type type) -> Type { return type; });
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
if (!isLinearizableVector(type))
bool unsupported =
isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
op);
if (!unsupported)
return false;

// Check if any of the results is a scalable vector type.
auto types = op->getResultTypes();
bool containsScalableResult =
std::any_of(types.begin(), types.end(), [](Type type) {
auto vecType = dyn_cast<VectorType>(type);
return vecType && vecType.isScalable();
});

return containsScalableResult;
}

static bool isNotLinearizable(Operation *op) {

// Only ops that are in the vector dialect, are ConstantLike, or
// are Vectorizable might be linearized currently.
StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
StringRef opDialect = op->getDialect()->getNamespace();
bool unsupported = (opDialect != vectorDialect) &&
!op->hasTrait<OpTrait::ConstantLike>() &&
!op->hasTrait<OpTrait::Vectorizable>();
if (unsupported)
return true;

// Some ops currently don't support scalable vectors.
if (isNotLinearizableBecauseScalable(op))
return true;

return false;
}

void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
ConversionTarget &target) {

auto convertType = [](Type type) -> std::optional<Type> {
VectorType vectorType = dyn_cast<VectorType>(type);
if (!vectorType || !isLinearizableVector(vectorType))
return type;

VectorType linearizedType =
Expand All @@ -543,14 +525,11 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {

target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
if ((isa<vector::BitCastOp, vector::SplatOp>(op) ||
op->hasTrait<OpTrait::ConstantLike>() ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
? typeConverter.isLegal(op)
: true);
}
return std::nullopt;
if (isNotLinearizable(op))
return true;
// This will return true if, for all operand and result types `t`,
// convertType(t) = t. This is true if there are no rank>=2 vectors.
return typeConverter.isLegal(op);
});
}

Expand All @@ -559,7 +538,7 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
RewritePatternSet &patterns) {
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
LinearizeVectorBitCast, LinearizeVectorSplat>(
typeConverter, patterns.getContext(), targetBitWidth);
typeConverter, patterns.getContext());
}

void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Dialect/Vector/linearize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,20 @@ func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
%0 = vector.splat %arg0 : vector<4x2xi32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also work (i.e. scalable vector):

  %0 = vector.splat %arg0 : vector<4x[2]xi32>

Could you try and if it works, add a test? Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added the test

return %0 : vector<4x2xi32>
}

// -----
// ALL-LABEL: linearize_scalable_vector_splat
// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
// DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32>
// DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32>
// DEFAULT: return %[[CAST]] : vector<4x[2]xi32>
// BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32>
// BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32>
// BW-128: return %[[CAST]] : vector<4x[2]xi32>

// BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x[2]xi32>
// BW-0: return %[[SPLAT]] : vector<4x[2]xi32>
%0 = vector.splat %arg0 : vector<4x[2]xi32>
return %0 : vector<4x[2]xi32>
}
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.