Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 9ac1f34

Browse files
[Linalg] Fix bug in control function logic of push down extract pattern (llvm#158348)
Current logic just bails out if the first extract producer fails the control function, this PR fixes that. Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent 120d747 commit 9ac1f34

File tree

3 files changed

+60
-15
lines changed

3 files changed

+60
-15
lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,21 +1245,21 @@ struct SliceDimInfo {
12451245
OpFoldResult outputSize;
12461246
};
12471247

1248-
/// Return the first input extract slice operand, if present, for the current
1248+
/// Return all extract slice operands, if present, for the current
12491249
/// generic op.
1250-
static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
1251-
OpOperand *sliceOperand = nullptr;
1250+
static FailureOr<SmallVector<OpOperand *>>
1251+
getSliceOperands(GenericOp genericOp) {
1252+
SmallVector<OpOperand *> sliceOperands;
12521253
for (auto operand : genericOp.getDpsInputOperands()) {
12531254
auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
12541255
if (!extractOp)
12551256
continue;
1256-
sliceOperand = operand;
1257-
break;
1257+
sliceOperands.push_back(operand);
12581258
}
1259-
if (!sliceOperand) {
1259+
if (sliceOperands.empty()) {
12601260
return failure();
12611261
}
1262-
return sliceOperand;
1262+
return sliceOperands;
12631263
}
12641264

12651265
// Return a map of dims that have partial slices on them so that other operands
@@ -1336,14 +1336,24 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
13361336
genericOp,
13371337
"propagation through generic with gather semantics is unsupported.");
13381338
// Collect the sliced operand, if present.
1339-
auto maybeSliceOperand = getSliceOperand(genericOp);
1340-
if (failed(maybeSliceOperand))
1339+
auto maybeSliceOperands = getSliceOperands(genericOp);
1340+
if (failed(maybeSliceOperands))
13411341
return failure();
1342-
OpOperand *sliceOperand = *maybeSliceOperand;
1343-
unsigned OperandIndex = sliceOperand->getOperandNumber();
1344-
1345-
if (!controlFn(sliceOperand))
1342+
SmallVector<OpOperand *> sliceOperands = *maybeSliceOperands;
1343+
OpOperand *sliceOperand;
1344+
1345+
bool foundValidOperand = false;
1346+
for (auto currSliceOperand : sliceOperands) {
1347+
if (controlFn(currSliceOperand)) {
1348+
sliceOperand = currSliceOperand;
1349+
foundValidOperand = true;
1350+
break;
1351+
}
1352+
}
1353+
if (!foundValidOperand) {
13461354
return failure();
1355+
}
1356+
unsigned OperandIndex = sliceOperand->getOperandNumber();
13471357

13481358
tensor::ExtractSliceOp producerSliceOp =
13491359
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,3 +1577,33 @@ func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>
15771577
// CHECK: %[[GENERIC:.+]] = linalg.generic
15781578
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]
15791579
// CHECK: return %[[EXTRACT]]
1580+
1581+
// -----
1582+
// Test that if one extract doesnt pass the control function which in this case is set to
1583+
// only allow extracts from the same block, then an extract from a later operand can still be pushed
1584+
// down.
1585+
func.func @push_extract_through_generic_secondextract(%arg0: tensor<128x128xf32>, %arg1: tensor<?x?xbf16>, %arg2: index) -> tensor<?x?xbf16> {
1586+
%c0 = arith.constant 0 : index
1587+
%c32 = arith.constant 32 : index
1588+
%extracted_slice1 = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
1589+
%for = scf.for %arg3 = %c0 to %c32 step %arg2 iter_args(%arg4 = %arg1) -> tensor<?x?xbf16> {
1590+
%extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
1591+
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d0, d1)> ,affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice1, %extracted_slice : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg1 : tensor<?x?xbf16>) {
1592+
^bb0(%in: f32, %in_1 : f32, %out: bf16):
1593+
%1 = arith.truncf %in : f32 to bf16
1594+
linalg.yield %1 : bf16
1595+
} -> tensor<?x?xbf16>
1596+
scf.yield %0 : tensor<?x?xbf16>
1597+
}
1598+
return %for : tensor<?x?xbf16>
1599+
}
1600+
1601+
// CHECK-LABEL: func.func @push_extract_through_generic_secondextract
1602+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1603+
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
1604+
// CHECK: %[[FOR:.+]] = scf.for
1605+
// CHECK: %[[PAD:.+]] = tensor.pad %[[EXTRACT]]
1606+
// CHECK: %[[GENERIC:.+]] = linalg.generic
1607+
// CHECK-SAME: ins(%[[PAD]], %[[ARG0]]
1608+
// CHECK: %[[EXTRACT2:.+]] = tensor.extract_slice %[[GENERIC]]
1609+
// CHECK: scf.yield %[[EXTRACT2]]

mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,13 @@ struct TestDataLayoutPropagationPass
3434
RewritePatternSet patterns(context);
3535
linalg::populateDataLayoutPropagationPatterns(
3636
patterns, [](OpOperand *opOperand) { return true; });
37-
linalg::populateExtractSliceSinkingPatterns(
38-
patterns, [](OpOperand *opOperand) { return true; });
37+
linalg::ControlPropagationFn controlExtract =
38+
[](OpOperand *opOperand) -> bool {
39+
Operation *producer = opOperand->get().getDefiningOp();
40+
Operation *consumer = opOperand->getOwner();
41+
return consumer->getBlock() == producer->getBlock();
42+
};
43+
linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
3944
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
4045
return signalPassFailure();
4146
}

0 commit comments

Comments
 (0)