Commit 46a779c
Add support for bufferizing LinalgExt ops. (iree-org#6377)
Since LinalgExtInterface is a subset of LinalgInterface, we can use a template in convertAnyLinalgOp to handle both. The analyseLinalg*Ops functions have different implementations because we don't define indexing maps for LinalgExtOp. Also adds an interface method, clone. This is a step towards iree-org#6154.
1 parent 8b5da81 commit 46a779c
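Because LinalgExtInterface exposes the same operand and semantics queries as LinalgInterface, one templated conversion function can serve both. A minimal standalone sketch of that duck-typing pattern, using mock types rather than the real MLIR/IREE classes:

```cpp
#include <iostream>
#include <vector>

// Mock op types (hypothetical, not the real MLIR classes): each exposes the
// small method subset that LinalgExtInterface shares with LinalgInterface.
struct MockLinalgOp {
  bool hasBufferSemantics() const { return false; }
  std::vector<int> getInputOperands() const { return {0, 1}; }
  std::vector<int> getOutputOperands() const { return {2}; }
};
struct MockLinalgExtOp {
  bool hasBufferSemantics() const { return false; }
  std::vector<int> getInputOperands() const { return {0}; }
  std::vector<int> getOutputOperands() const { return {1}; }
};

// Mirrors the shape of the patch: convertAnyLinalgOp is templated on the op
// type instead of hard-coding linalg::LinalgOp, so both kinds of op reuse it.
template <typename OpTy>
bool convertAnyLinalgOp(OpTy op) {
  if (op.hasBufferSemantics()) return true;  // skip already-bufferized ops
  for (int i : op.getInputOperands()) std::cout << "map input " << i << "\n";
  for (int o : op.getOutputOperands()) std::cout << "map output " << o << "\n";
  return true;
}

int main() {
  convertAnyLinalgOp(MockLinalgOp{});
  convertAnyLinalgOp(MockLinalgExtOp{});
}
```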

6 files changed: +82 lines, -7 lines


iree/compiler/Codegen/Common/BUILD (1 addition, 0 deletions)

```diff
@@ -50,6 +50,7 @@ cc_library(
         "//iree/compiler/Dialect/Flow/IR",
         "//iree/compiler/Dialect/HAL/IR",
         "//iree/compiler/Dialect/IREE/IR",
+        "//iree/compiler/Dialect/LinalgExt/IR",
         "//iree/compiler/Dialect/Shape/IR",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:Affine",
```

iree/compiler/Codegen/Common/CMakeLists.txt (1 addition, 0 deletions)

```diff
@@ -58,6 +58,7 @@ iree_cc_library(
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::IREE::IR
+    iree::compiler::Dialect::LinalgExt::IR
     iree::compiler::Dialect::Shape::IR
   PUBLIC
 )
```

iree/compiler/Codegen/Common/LinalgBufferizePass.cpp (34 additions, 7 deletions)

```diff
@@ -46,6 +46,7 @@
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
 #include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
 #include "llvm/ADT/EquivalenceClasses.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -380,6 +381,26 @@ static SmallVector<Value> getTiedOperandsForLinalgOps(
   return tiedOperands;
 }
 
+static LogicalResult analyseLinalgExtOps(linalg_ext::LinalgExtOp op,
+                                         BufferizationPlan &plan) {
+  if (!op.hasTensorSemantics()) return success();
+  // TODO(hanchung): Revisit whether we can tie together op.getOutputOperands()
+  // with the corresponding op.getInputOperands(). For now there are only a
+  // few LinalgExt ops and no use case for it, so we ignore it.
+  // Note: this is also what should be done for LinalgOps, except for what is
+  // done for operand fusion today.
+  for (auto input : op.getInputOperands()) {
+    plan.insert(input->get());
+  }
+  for (auto output : op.getOutputOperands()) {
+    plan.insert(output->get());
+  }
+  for (auto result : op->getResults()) {
+    plan.insert(result);
+  }
+  return success();
+}
+
 /// Adds the corresponding `outs` and result tensors of the linalg op into the
 /// same equivalence class.
 static LogicalResult analyseLinalgOps(linalg::LinalgOp linalgOp,
@@ -580,6 +601,10 @@ static LogicalResult analyseOperations(FuncOp funcOp, BufferizationPlan &plan) {
         .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
           return analyseLinalgOps(linalgOp, plan);
        })
+        .Case<linalg_ext::LinalgExtOp>(
+            [&](linalg_ext::LinalgExtOp linalgExtOp) {
+              return analyseLinalgExtOps(linalgExtOp, plan);
+            })
         .Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
             [&](auto reshapeOp) {
               return analyseSingleOperandResultOp(reshapeOp.src(),
@@ -910,7 +935,8 @@ static Value getInplaceResultBuffer(OpBuilder &b, OpResult resultValue,
   resultBuffer =
       TypeSwitch<Operation *, Value>(op)
           .Case<scf::IfOp, scf::ForOp, linalg::LinalgOp,
-                tensor::InsertSliceOp, vector::TransferWriteOp>(
+                linalg_ext::LinalgExtOp, tensor::InsertSliceOp,
+                vector::TransferWriteOp>(
               [&](auto op) { return resultBuffer; })
           .Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
               [&](auto reshapeOp) {
@@ -1123,9 +1149,10 @@ static LogicalResult getOrAllocateResultBuffers(
 /// Generic conversion pattern that matches any linalg::LinalgOp. This avoids
 /// template instantiating one pattern for each linalg::LinalgOp. The method
 /// expects all operands and results have already been mapped to memrefs.
+template <typename OpTy>
 static LogicalResult convertAnyLinalgOp(
-    OpBuilder &b, linalg::LinalgOp op, BlockAndValueMapping &bvm,
-    BufferizationPlan &plan, WorkgroupMemoryAllocationFn allocationFn) {
+    OpBuilder &b, OpTy op, BlockAndValueMapping &bvm, BufferizationPlan &plan,
+    WorkgroupMemoryAllocationFn allocationFn) {
   // Skip linalg ops inserted by this pass.
   if (op.hasBufferSemantics()) return success();
 
@@ -1539,12 +1566,12 @@ void LinalgBufferizePass::runOnOperation() {
           }
           return convertPadTensorOp(b, padTensorOp, bvm);
         })
-        .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
-          if (failed(getOrAllocateResultBuffers(b, linalgOp.getOperation(), bvm,
-                                                plan, allocationFn))) {
+        .Case<linalg::LinalgOp, linalg_ext::LinalgExtOp>([&](auto op) {
+          if (failed(
+                  getOrAllocateResultBuffers(b, op, bvm, plan, allocationFn))) {
             return failure();
           }
-          return convertAnyLinalgOp(b, linalgOp, bvm, plan, allocationFn);
+          return convertAnyLinalgOp(b, op, bvm, plan, allocationFn);
         })
        .Case<tensor::InsertSliceOp>(
            [&](tensor::InsertSliceOp subTensorInsertOp) {
```
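The BufferizationPlan that analyseLinalgExtOps populates is built on llvm::EquivalenceClasses (note the llvm/ADT/EquivalenceClasses.h include in the first hunk): values that may share a buffer are unioned into one set. Per the TODO, the LinalgExt analysis deliberately unions nothing; it only inserts each operand and result as a singleton, so each tensor gets its own buffer. A standalone sketch of that bookkeeping, with ints standing in for mlir::Value and requiring only LLVM's ADT headers (BufferizationPlan itself is IREE-specific and not shown):

```cpp
#include "llvm/ADT/EquivalenceClasses.h"
#include <cstdio>

int main() {
  llvm::EquivalenceClasses<int> plan;
  // analyseLinalgOps ties a linalg op's `outs` operand to its result, so the
  // two later share one buffer.
  plan.unionSets(/*outsOperand=*/1, /*result=*/2);
  // analyseLinalgExtOps only inserts singletons: no tying, separate buffers.
  plan.insert(/*input=*/3);
  plan.insert(/*result=*/4);
  std::printf("outs ~ result:  %d\n", plan.isEquivalent(1, 2));  // prints 1
  std::printf("input ~ result: %d\n", plan.isEquivalent(3, 4));  // prints 0
}
```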

iree/compiler/Codegen/Common/test/linalg_bufferize.mlir (20 additions, 0 deletions)

```diff
@@ -2384,3 +2384,23 @@ hal.interface @io attributes {sym_visibility = "private"} {
 // CHECK: scf.if
 // CHECK-DAG: memref.store %[[V1]], %[[INOUT]][%[[P1]]]
 // CHECK-DAG: memref.store %[[V2]], %[[INOUT]][%[[ARG1]]]
+
+// -----
+
+func @linalg_ext_sort_1d() {
+  %c0 = constant 0 : index
+  %0 = hal.interface.binding.subspan @io::@rw[%c0] : !flow.dispatch.tensor<readwrite:128xi32>
+  %1 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readwrite:128xi32> -> tensor<128xi32>
+  %2 = linalg_ext.sort {dimension = 0 : i64} outs(%1 : tensor<128xi32>) {
+  ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
+    %3 = cmpi sgt, %arg0, %arg1 : i32
+    linalg_ext.yield %3 : i1
+  } -> tensor<128xi32>
+  flow.dispatch.tensor.store %2, %0, offsets = [], sizes = [], strides = [] : tensor<128xi32> -> !flow.dispatch.tensor<readwrite:128xi32>
+  return
+}
+// CHECK-LABEL: func @linalg_ext_sort_1d()
+// CHECK-DAG: %[[INOUT:.+]] = hal.interface.binding.subspan @io::@rw
+// CHECK: linalg_ext.sort
+// CHECK-SAME: dimension = 0 : i64
+// CHECK-SAME: outs(%[[INOUT]] : memref<128xi32>)
```
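As the CHECK lines capture, the tensor loaded from and stored back to the readwrite @io::@rw binding is expected to bufferize to that binding's own buffer, so the sort runs in place on the subspan's memref<128xi32> rather than on a fresh allocation.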

iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h (2 additions, 0 deletions)

```diff
@@ -7,6 +7,8 @@
 #ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
 #define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
 
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"
```

iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td (24 additions, 0 deletions)

```diff
@@ -438,6 +438,30 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
           return opOperand->get().getType().template isa<RankedTensorType>();
         });
       }]
+    >,
+    //===------------------------------------------------------------------===//
+    // Other static interface methods.
+    //===------------------------------------------------------------------===//
+    InterfaceMethod<
+      /*desc=*/[{
+        Clone the current operation with the given location and operands. This
+        is used to abstract away the optional underlying region creation. This
+        does not change the balance between input, output_buffer and
+        init_tensors operands.
+      }],
+      /*retTy=*/"Operation *",
+      /*methodName=*/"clone",
+      (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
+           "ValueRange":$operands),
+      [{
+        BlockAndValueMapping bvm;
+        OperationState state(
+            loc, ConcreteOp::getOperationName(), operands, resultTypes,
+            $_op->getAttrs());
+        for (Region &r : $_op->getRegions())
+          r.cloneInto(state.addRegion(), bvm);
+        return b.createOperation(state);
+      }]
     >
   ];
 
```
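The new clone method is what lets a bufferization rewrite recreate a LinalgExt op on memrefs without knowing its concrete type, carrying the comparator region along. A hedged sketch of a caller, assuming the namespace context of LinalgBufferizePass.cpp; recreateOnBuffers and the contents of bvm are hypothetical, not part of this patch:

```cpp
// Sketch only. `bvm` is assumed to already map each tensor operand to its
// allocated memref; with buffer semantics the cloned op produces no results,
// so resultTypes is empty and the op writes into its `outs` buffers.
static void recreateOnBuffers(OpBuilder &b, linalg_ext::LinalgExtOp op,
                              BlockAndValueMapping &bvm) {
  SmallVector<Value> newOperands;
  for (Value operand : op.getOperation()->getOperands())
    newOperands.push_back(bvm.lookupOrDefault(operand));
  b.setInsertionPoint(op.getOperation());
  // clone() copies the attributes and regions; only operands/results change.
  op.clone(b, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands);
}
```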
