-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][vector] Refactor createWriteOrMaskedWrite
#138137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][vector] Refactor createWriteOrMaskedWrite
#138137
Conversation
This patch updates `createWriteOrMaskedWrite` to make it consistent with `createReadOrMaskedRead`. Before diving into the details: note that these utilities are currently implemented in different files — "VectorUtils.cpp" (Vector) and "Vectorization.cpp" (Linalg). In a subsequent patch, I plan to move `createWriteOrMaskedWrite` into "VectorUtils.cpp". SUMMARY OF CHANGES: The main change is to remove the logic that creates the destination tensor, which previously looked like: ```cpp Value dest = builder.create<tensor::EmptyOp>(loc, destSizes, inputType.getElementType()); ``` With this patch, createWriteOrMaskedWrite now simply generates: ```mlir %res = vector.transfer_write %vectorToStore into %dest ``` This replaces the previous form: ```mlir %dest = tensor.empty(%destSizes) %res = vector.transfer_write %vectorToStore into %dest ``` In other words, the destination value `%dest` is now passed as an input parameter. This makes `createWriteOrMaskedWrite` re-usable in contexts where the destination tensor is already known — for example, in `vectorizeAsInsertSliceOp`, which I will update in a follow-up patch. OTHER CHANGES: * Added comments and clarified TODOs. * Updated tests: since destination sizes are now computed independently inside `createWriteOrMaskedWrite`, some additional `tensor.dim` ops appear. These will be cleaned up by CSE + canonicalization.
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesThis patch updates Before diving into the details: note that these utilities are currently SUMMARY OF CHANGES: The main change is to remove the logic that creates the destination Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
inputType.getElementType()); With this patch, createWriteOrMaskedWrite now simply generates: %res = vector.transfer_write %vectorToStore into %dest This replaces the previous form: %dest = tensor.empty(%destSizes)
%res = vector.transfer_write %vectorToStore into %dest In other words, the destination value OTHER CHANGES:
Full diff: https://github.com/llvm/llvm-project/pull/138137.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a477c2fb3f8cb..12ecdf9494bef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1506,72 +1506,68 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
}
-/// Creates a TransferWriteOp to write `input` into a newly initialized
-/// output tensor.
+/// Creates an optionally masked TransferWriteOp
///
-/// Given:
-/// - an input vector to write,
-/// - the mixed destination sizes for the output tensor,
-/// - and the vector sizes used for vectorization (i.e., the leading N dims,
-/// for some value of N),
-///
-/// this function generates the following sequence of ops:
-///
-/// %dest = tensor.empty(%destSizes)
-/// %res = vector.transfer_write %input into %dest
+/// Generates the following operation:
+/// %res = vector.transfer_write %vectorToStore into %dest
///
/// If the leading N dimensions of the destination tensor do not match
-/// `inputVecSizesForLeadingDims` (where N =
-/// rank(`inputVecSizesForLeadingDims`)), masking is applied to ensure
-/// correctness:
+/// `inputVecSizesForLeadingDims`, where=
+/// * N = rank(`inputVecSizesForLeadingDims`)),
+/// masking is applied to ensure correctness:
///
-/// %dest = tensor.empty(%destSizes)
-/// %write = vector.transfer_write %input into %dest
-/// %mask = vector.create_mask(%destSizes)
+/// %write = vector.transfer_write %vectorToStore into %dest
+/// %mask = vector.create_mask(%destShape)
/// %res = vector.mask %mask { %write }
///
/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
/// is used instead of masking:
///
-/// %dest = tensor.empty(%destSizes)
+/// %write = vector.transfer_write %vectorToStore into %dest
/// in_bounds_flags = (...)
/// %res = vector.transfer_write %input into %dest
/// {in_bounds = in_bounds_flags}
///
-/// NOTE: all write offsets are set to 0.
+/// NOTE: All write offsets are set to 0.
+/// TODO: Allow specyfying write offsets.
/// NOTE: When N < rank(input), the missing vector sizes are effectively
/// extracted from the trailing sizes of `destSizes`. This means those sizes
-/// must be static. Supporting dynamic sizes will require the user to specify
-/// the remaining vector sizes. This is left as a TODO.
+/// must be static.
+/// TODO: Support cases where an arbitrary dim is dynamic - this will require
+/// specifying all the vector sizes.
static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
- SmallVector<OpFoldResult> destSizes,
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
+ Value dest,
ArrayRef<int64_t> inputVecSizesForLeadingDims,
bool useInBoundsInsteadOfMasking = false) {
- auto inputType = cast<VectorType>(input.getType());
- assert(inputType.getRank() == static_cast<int64_t>(destSizes.size()) &&
+ ShapedType destType = cast<ShapedType>(dest.getType());
+ assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
+ static_cast<int64_t>(destType.getRank()) &&
"Rank mismatch!");
- Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
- inputType.getElementType());
int64_t rank = cast<ShapedType>(dest.getType()).getRank();
- auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
auto destShape = cast<ShapedType>(dest.getType()).getShape();
+
+ // Compute the in_bounds attribute
SmallVector<bool> inBoundsVal(rank, true);
if (useInBoundsInsteadOfMasking) {
// In this case, assume that all the required vector sizes have been
// provided.
- assert(inputVecSizesForLeadingDims.size() == destSizes.size() &&
+ assert(inputVecSizesForLeadingDims.size() ==
+ static_cast<size_t>(destType.getRank()) &&
"Insufficient number of input vector sizes!");
// Update the inBounds attribute.
for (unsigned i = 0; i < rank; i++)
inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
!ShapedType::isDynamic(destShape[i]);
}
+
+ // Generate the xfer_write Op
+ auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Operation *write = builder.create<vector::TransferWriteOp>(
loc,
- /*vector=*/input,
+ /*vector=*/vectorToStore,
/*source=*/dest,
/*indices=*/SmallVector<Value>(rank, zero),
/*inBounds=*/inBoundsVal);
@@ -1579,11 +1575,17 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
destShape.drop_front(inputVecSizesForLeadingDims.size()),
[](int64_t size) { return size == ShapedType::kDynamic; }) &&
"Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+
+ // If masking is disabled, exit.
if (useInBoundsInsteadOfMasking)
return write;
+
+ // Check if masking is needed.
bool needMaskForWrite =
!llvm::equal(inputVecSizesForLeadingDims,
destShape.take_front(inputVecSizesForLeadingDims.size()));
+
+ // If masking is needed, generate the mask and mask the operation.
if (needMaskForWrite) {
SmallVector<int64_t> writeMaskShape;
writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
@@ -1592,10 +1594,11 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
inputVecSizesForLeadingDims.size(),
destShape.end());
auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
- Value maskForWrite =
- builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
+ Value maskForWrite = builder.create<vector::CreateMaskOp>(
+ loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
write = mlir::vector::maskOperation(builder, write, maskForWrite);
}
+
return write;
}
@@ -1693,9 +1696,11 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
loc, shapeCastOp.getResult(), destPermutation);
// Create TransferWriteOp.
+ Value dest = rewriter.create<tensor::EmptyOp>(
+ loc, reifiedReturnShapes[0],
+ transposeOp.getResult().getType().getElementType());
Operation *write =
- createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
- /*destSizes=*/reifiedReturnShapes[0],
+ createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
/*inputVecSizesForLeadingDims=*/inputVectorSizes,
/*useInBoundsInsteadOfMasking=*/false);
newResults.push_back(write->getResult(0));
@@ -1830,10 +1835,13 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
unpackOp.getDestType().hasStaticShape()
? vectorSizes
: shapeCastOp.getResultVectorType().getShape());
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, shapeCastOp.getResult(), /*destSizes=*/reifiedRetShapes[0],
- /*inputVecSizesForLeadingDims=*/writeVectorSizes,
- useInBoundsInsteadOfMasking);
+ Value dest = rewriter.create<tensor::EmptyOp>(
+ loc, reifiedRetShapes[0],
+ shapeCastOp.getResult().getType().getElementType());
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
+ /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+ useInBoundsInsteadOfMasking);
newResults.push_back(write->getResult(0));
return success();
}
@@ -1861,10 +1869,14 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
auto maskedRead = vector::createReadOrMaskedRead(
rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
/*useInBoundsInsteadOfMasking=*/false);
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, maskedRead, reifiedReturnShapes[0],
- /*inputVecSizesForLeadingDims=*/inputVectorSizes,
- /*useInBoundsInsteadOfMasking=*/false);
+
+ // Create Xfer write Op
+ Value dest = rewriter.create<tensor::EmptyOp>(
+ loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
+ /*inputVecSizesForLeadingDims=*/inputVectorSizes,
+ /*useInBoundsInsteadOfMasking=*/false);
newResults.push_back(write->getResult(0));
return success();
}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 299be1296aa66..6b760a15afd56 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -641,7 +641,9 @@ func.func @test_masked_vectorize_dynamic_pad(
// CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
// CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[res_d0]], %[[res_d1]]) : tensor<?x?xf32>
// CHECK-DAG: %[[c0_3:.*]] = arith.constant 0 : index
- // CHECK: %[[mask_2:.*]] = vector.create_mask %[[res_d0]], %[[res_d1]] : vector<2x4xi1>
+ // CHECK-DAG: %[[d2:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?xf32>
+ // CHECK-DAG: %[[d3:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?xf32>
+ // CHECK: %[[mask_2:.*]] = vector.create_mask %[[d2]], %[[d3]] : vector<2x4xi1>
// CHECK: %[[masked_write:.*]] = vector.mask %[[mask_2]] {
// CHECK-SAME: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_3]], %[[c0_3]]]
// CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<?x?xf32>
@@ -800,7 +802,9 @@ func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?
// CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[d0]], %[[d1]]) : tensor<?x?x16x2xf32>
-// CHECK: %[[mask_0:.*]] = vector.create_mask %[[d0]], %[[d1]], %[[c16]], %[[c2]] : vector<4x1x16x2xi1>
+// CHECK-DAG: %[[d2:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?x16x2xf32>
+// CHECK-DAG: %[[d3:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?x16x2xf32>
+// CHECK: %[[mask_0:.*]] = vector.create_mask %[[d2]], %[[d3]], %[[c16]], %[[c2]] : vector<4x1x16x2xi1>
// CHECK: %[[masked_write:.*]] = vector.mask %[[mask_0]] {
// CHECK-SAME: vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>
|
This patch updates
createWriteOrMaskedWrite
to make it consistent withcreateReadOrMaskedRead
.Before diving into the details: note that these utilities are currently
implemented in different files — "VectorUtils.cpp" (Vector) and
"Vectorization.cpp" (Linalg). In a subsequent patch, I plan to move
createWriteOrMaskedWrite
into "VectorUtils.cpp".SUMMARY OF CHANGES:
The main change is to remove the logic that creates the destination
tensor, which previously looked like:
With this patch, createWriteOrMaskedWrite now simply generates:
This replaces the previous form:
In other words, the destination value
%dest
is now passed as an inputparameter. This makes
createWriteOrMaskedWrite
re-usable in contextswhere the destination tensor is already known — for example, in
vectorizeAsInsertSliceOp
, which I will update in a follow-up patch.OTHER CHANGES:
Added comments and clarified TODOs.
Updated tests: since destination sizes are now computed independently
inside
createWriteOrMaskedWrite
, some additionaltensor.dim
opsappear. These will be cleaned up by CSE + canonicalization.