Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
push further with the separation of concerns
  • Loading branch information
newling committed Apr 22, 2025
commit be48849486b1c1ae68568dee941acc2bc7d49951
Original file line number Diff line number Diff line change
Expand Up @@ -392,24 +392,29 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populate `typeConverter` and `conversionTarget` with the definition of
/// legal types and operations, for the specific case where vectors with
/// trailing dimensions of size greater than `targetBitWidth` are legal.
void populateVectorLinearizeBitWidthTargetAndConverter(
TypeConverter &typeConverter, ConversionTarget &conversionTarget,
unsigned targetBitWidth);

/// Populates `patterns` for ND vector (N >= 2) linearization. Patterns for
/// converting ConstantLike, Vectorizable, and vector::BitCast.
/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
/// This registers (1) which operations are legal and hence should not be
/// linearized, (2) what converted types are (rank-1 vectors) and how to
/// materialze the conversion (with shape_cast)
///
/// Note: the set of legal operations can be extended by a user if for example
/// certain rank>1 vectors are considered valid, but adding additional
/// dynamically legal ops to `conversionTarget`.
void populateForVectorLinearize(TypeConverter &typeConverter,
ConversionTarget &conversionTarget);

/// Populates `patterns` for ND vector (N >= 2) linearization. This currently
/// contains patterns for converting ConstantLike, Vectorizable, and
/// vector::BitCast ops.
void populateVectorLinearizeBasePatterns(const TypeConverter &,
RewritePatternSet &patterns,
const ConversionTarget &);
const ConversionTarget &,
RewritePatternSet &patterns);

/// Populates `patterns` for linearizing ND (N >= 2) vector operations
/// to 1D vector shuffle operations.
void populateVectorLinearizeShuffleLikeOpsPatterns(const TypeConverter &,
RewritePatternSet &patterns,
const ConversionTarget &);
const ConversionTarget &,
RewritePatternSet &patterns);

} // namespace vector
} // namespace mlir
Expand Down
167 changes: 52 additions & 115 deletions mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,10 @@ struct LinearizeConstantLike final
if (op->getNumResults() != 1)
return rewriter.notifyMatchFailure(loc, "expected 1 result");

const TypeConverter &converter = *getTypeConverter();
const TypeConverter &typeConverter = *getTypeConverter();
auto resType =
converter.convertType<VectorType>(op->getResult(0).getType());

if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
typeConverter.convertType<VectorType>(op->getResult(0).getType());
assert(resType && "expected 1-D vector type");

StringAttr attrName = rewriter.getStringAttr("value");
Attribute value = op->getAttr(attrName);
Expand All @@ -80,7 +78,7 @@ struct LinearizeConstantLike final
return failure();

FailureOr<Operation *> convertResult =
convertOpResultTypes(op, /*operands=*/{}, converter, rewriter);
convertOpResultTypes(op, /*operands=*/{}, typeConverter, rewriter);
if (failed(convertResult))
return failure();

Expand Down Expand Up @@ -244,14 +242,6 @@ struct LinearizeVectorShuffle final
VectorType dstType =
getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
assert(dstType && "vector type destination expected.");
// The assert is used because vector.shuffle does not support scalable
// vectors.
bool scalable = shuffleOp.getV1VectorType().isScalable() ||
shuffleOp.getV2VectorType().isScalable() ||
dstType.isScalable();
if (scalable)
return rewriter.notifyMatchFailure(shuffleOp,
"scalable vectors are not supported.");

Value vec1 = adaptor.getV1();
Value vec2 = adaptor.getV2();
Expand All @@ -270,7 +260,7 @@ struct LinearizeVectorShuffle final
}

// For each value in the mask, we generate the indices of the source vectors
// that needs to be shuffled to the destination vector. If shuffleSliceLen >
// that need to be shuffled to the destination vector. If shuffleSliceLen >
// 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
// elements) instead of scalars.
ArrayRef<int64_t> mask = shuffleOp.getMask();
Expand Down Expand Up @@ -309,14 +299,7 @@ struct LinearizeVectorExtract final
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
if (!dstTy)
return rewriter.notifyMatchFailure(extractOp,
"expected n-D vector type.");

if (extractOp.getVector().getType().isScalable() ||
cast<VectorType>(dstTy).isScalable())
return rewriter.notifyMatchFailure(extractOp,
"scalable vectors are not supported.");
assert(dstTy && "expected 1-D vector type");

// Dynamic position is not supported.
if (extractOp.hasDynamicPosition())
Expand Down Expand Up @@ -367,9 +350,6 @@ struct LinearizeVectorInsert final
VectorType dstTy = getTypeConverter()->convertType<VectorType>(
insertOp.getDestVectorType());
assert(dstTy && "vector type destination expected.");
if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
return rewriter.notifyMatchFailure(insertOp,
"scalable vectors are not supported.");

// dynamic position is not supported
if (insertOp.hasDynamicPosition())
Expand Down Expand Up @@ -436,11 +416,8 @@ struct LinearizeVectorBitCast final
LogicalResult
matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = castOp.getLoc();
auto resType = getTypeConverter()->convertType(castOp.getType());
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type.");

assert(resType && "expected 1-D vector type");
rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
adaptor.getSource());
return mlir::success();
Expand All @@ -449,56 +426,15 @@ struct LinearizeVectorBitCast final

} // namespace

/// If `type` is VectorType with trailing dimension of (bit) size greater than
/// or equal to `targetBitWidth`, its defining op is considered legal.
static bool legalBecauseOfBitwidth(Type type, unsigned targetBitWidth) {

VectorType vecType = dyn_cast<VectorType>(type);

if (!vecType)
return true;

// The width of the type 'index' is unbounded (and therefore potentially above
// the target width).
if (vecType.getElementType().isIndex())
return true;

unsigned finalDimSize =
vecType.getRank() == 0 ? 0 : vecType.getShape().back();

unsigned trailingVecDimBitWidth =
finalDimSize * vecType.getElementTypeBitWidth();

return trailingVecDimBitWidth >= targetBitWidth;
}

static SmallVector<std::pair<Type, unsigned>>
getChecksForBitwidth(Operation *op, unsigned targetBitWidth) {

if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
auto w = targetBitWidth < std::numeric_limits<unsigned>::max()
? targetBitWidth + 1
: targetBitWidth;
return {{insertOp.getValueToStoreType(), w}};
}
auto resultTypes = op->getResultTypes();
SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
resultsWithBitWidth.reserve(resultTypes.size());
for (Type type : resultTypes) {
resultsWithBitWidth.push_back({type, targetBitWidth});
}
return resultsWithBitWidth;
}

/// Return true if the operation `op` does not support scalable vectors and
/// has at least 1 scalable vector result.
static bool legalBecauseScalable(Operation *op) {

bool scalableSupported = op->hasTrait<OpTrait::ConstantLike>() ||
op->hasTrait<OpTrait::Vectorizable>() ||
isa<vector::BitCastOp>(op);

if (scalableSupported)
/// has at least 1 scalable vector result. These ops should all eventually
/// support scalable vectors, and this function should be removed.
static bool isNotLinearizableBecauseScalable(Operation *op) {

bool unsupported =
isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
op);
if (!unsupported)
return false;

// Check if any of the results is a scalable vector type.
Expand All @@ -512,73 +448,74 @@ static bool legalBecauseScalable(Operation *op) {
return containsScalableResult;
}

static bool dynamicallyLegal(Operation *op, unsigned targetBitWidth) {
static bool isNotLinearizable(Operation *op) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this method be documented? What does it mean to be "non-linearizable"?

Also, I'm confused about isNotLinearizableBecauseScalable. Is linearization disabled in the presence of scalable vectors?


// Only ops that are in the vector dialect, are ConstantLike, or
// are Vectorizable might be linearized currently, so legalize the others.
bool opIsVectorDialect = op->getDialect()->getNamespace() ==
vector::VectorDialect::getDialectNamespace();
if (!opIsVectorDialect && !op->hasTrait<OpTrait::ConstantLike>() &&
!op->hasTrait<OpTrait::Vectorizable>())
// are Vectorizable might be linearized currently.
StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
StringRef opDialect = op->getDialect()->getNamespace();
bool unsupported = (opDialect != vectorDialect) &&
!op->hasTrait<OpTrait::ConstantLike>() &&
!op->hasTrait<OpTrait::Vectorizable>();
if (unsupported)
return true;

// Some ops will not be linearized if they have scalable vector results.
if (legalBecauseScalable(op))
// Some ops currently don't support scalable vectors.
if (isNotLinearizableBecauseScalable(op))
return true;

// Check on bitwidths.
auto typesToCheck = getChecksForBitwidth(op, targetBitWidth);
return std::any_of(typesToCheck.begin(), typesToCheck.end(),
[&](std::pair<Type, unsigned> typeWidth) {
return legalBecauseOfBitwidth(typeWidth.first,
typeWidth.second);
});
return false;
}

void mlir::vector::populateVectorLinearizeBitWidthTargetAndConverter(
TypeConverter &typeConverter, ConversionTarget &target,
unsigned targetBitWidth) {
void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
ConversionTarget &target) {

typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
if (!isLinearizableVector(type))
auto convertType = [](Type type) -> std::optional<Type> {
VectorType vectorType = dyn_cast<VectorType>(type);
if (!vectorType || !isLinearizableVector(vectorType))
return type;

return VectorType::get(type.getNumElements(), type.getElementType(),
type.isScalable());
});
VectorType linearizedType =
VectorType::get(vectorType.getNumElements(),
vectorType.getElementType(), vectorType.isScalable());
return linearizedType;
};
typeConverter.addConversion(convertType);

auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
!isa<VectorType>(type))
if (inputs.size() != 1)
return nullptr;
return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
};

Value value = inputs.front();
if (!isa<VectorType>(type) || !isa<VectorType>(value.getType()))
return nullptr;

return builder.create<vector::ShapeCastOp>(loc, type, value);
};
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);

target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
bool isDynamicallyLegal = dynamicallyLegal(op, targetBitWidth);
if (isDynamicallyLegal)
if (isNotLinearizable(op))
return true;

bool shapeUnchanged = typeConverter.isLegal(op);
return shapeUnchanged;
// This will return true if, for all operand and result types `t`,
// convertType(t) = t. This is true if there are no rank>=2 vectors.
return typeConverter.isLegal(op);
});
}

void mlir::vector::populateVectorLinearizeBasePatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
const ConversionTarget &target) {
const TypeConverter &typeConverter, const ConversionTarget &target,
RewritePatternSet &patterns) {
patterns.add<LinearizeConstantLike, LinearizeVectorizable,
LinearizeVectorBitCast>(typeConverter, patterns.getContext());
}

void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
const ConversionTarget &target) {
const TypeConverter &typeConverter, const ConversionTarget &target,
RewritePatternSet &patterns) {
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
typeConverter, patterns.getContext());
Expand Down
7 changes: 4 additions & 3 deletions mlir/test/Dialect/Vector/linearize.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0

// RUN: mlir-opt %s -split-input-file -test-bit-width-contrained-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
// RUN: mlir-opt %s -split-input-file -test-bit-width-contrained-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0

// ALL-LABEL: test_linearize
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
Expand Down Expand Up @@ -97,7 +98,7 @@ func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>

// ALL-LABEL: test_index_no_linearize
func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
// ALL: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
// BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
return %0 : vector<2x2xindex>
}
Expand Down
Loading
Loading