|
46 | 46 | #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
|
47 | 47 | #include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
|
48 | 48 | #include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
|
| 49 | +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" |
49 | 50 | #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
|
50 | 51 | #include "llvm/ADT/EquivalenceClasses.h"
|
51 | 52 | #include "llvm/ADT/TypeSwitch.h"
|
@@ -380,6 +381,26 @@ static SmallVector<Value> getTiedOperandsForLinalgOps(
|
380 | 381 | return tiedOperands;
|
381 | 382 | }
|
382 | 383 |
|
| 384 | +static LogicalResult analyseLinalgExtOps(linalg_ext::LinalgExtOp op, |
| 385 | + BufferizationPlan &plan) { |
| 386 | + if (!op.hasTensorSemantics()) return success(); |
| 387 | + // TODO(hanchung): Revisit if we can tie together op.getOutputOperands() with |
| 388 | + // the corresponding op.getInputOperands(). For now we have limited LinalgExt |
| 389 | + // ops, and there is no use case. So we ignore it. |
| 390 | + // Note: this is what should be done for LinalgOps, except for what is done |
| 391 | + // for operand fusion today. |
| 392 | + for (auto input : op.getInputOperands()) { |
| 393 | + plan.insert(input->get()); |
| 394 | + } |
| 395 | + for (auto output : op.getOutputOperands()) { |
| 396 | + plan.insert(output->get()); |
| 397 | + } |
| 398 | + for (auto result : op->getResults()) { |
| 399 | + plan.insert(result); |
| 400 | + } |
| 401 | + return success(); |
| 402 | +} |
| 403 | + |
383 | 404 | /// Adds the corresponding `outs` and result tensors of the linalg op into the
|
384 | 405 | /// same equivalence class.
|
385 | 406 | static LogicalResult analyseLinalgOps(linalg::LinalgOp linalgOp,
|
@@ -580,6 +601,10 @@ static LogicalResult analyseOperations(FuncOp funcOp, BufferizationPlan &plan) {
|
580 | 601 | .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
|
581 | 602 | return analyseLinalgOps(linalgOp, plan);
|
582 | 603 | })
|
| 604 | + .Case<linalg_ext::LinalgExtOp>( |
| 605 | + [&](linalg_ext::LinalgExtOp linalgExtOp) { |
| 606 | + return analyseLinalgExtOps(linalgExtOp, plan); |
| 607 | + }) |
583 | 608 | .Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
|
584 | 609 | [&](auto reshapeOp) {
|
585 | 610 | return analyseSingleOperandResultOp(reshapeOp.src(),
|
@@ -910,7 +935,8 @@ static Value getInplaceResultBuffer(OpBuilder &b, OpResult resultValue,
|
910 | 935 | resultBuffer =
|
911 | 936 | TypeSwitch<Operation *, Value>(op)
|
912 | 937 | .Case<scf::IfOp, scf::ForOp, linalg::LinalgOp,
|
913 |
| - tensor::InsertSliceOp, vector::TransferWriteOp>( |
| 938 | + linalg_ext::LinalgExtOp, tensor::InsertSliceOp, |
| 939 | + vector::TransferWriteOp>( |
914 | 940 | [&](auto op) { return resultBuffer; })
|
915 | 941 | .Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
|
916 | 942 | [&](auto reshapeOp) {
|
@@ -1123,9 +1149,10 @@ static LogicalResult getOrAllocateResultBuffers(
|
1123 | 1149 | /// Generic conversion pattern that matches any linalg::LinalgOp. This avoids
|
1124 | 1150 | /// template instantiating one pattern for each linalg::LinalgOp. The method
|
1125 | 1151 | /// expects all operands and results have already been mapped to memrefs.
|
| 1152 | +template <typename OpTy> |
1126 | 1153 | static LogicalResult convertAnyLinalgOp(
|
1127 |
| - OpBuilder &b, linalg::LinalgOp op, BlockAndValueMapping &bvm, |
1128 |
| - BufferizationPlan &plan, WorkgroupMemoryAllocationFn allocationFn) { |
| 1154 | + OpBuilder &b, OpTy op, BlockAndValueMapping &bvm, BufferizationPlan &plan, |
| 1155 | + WorkgroupMemoryAllocationFn allocationFn) { |
1129 | 1156 | // Skip linalg ops inserted by this pass.
|
1130 | 1157 | if (op.hasBufferSemantics()) return success();
|
1131 | 1158 |
|
@@ -1539,12 +1566,12 @@ void LinalgBufferizePass::runOnOperation() {
|
1539 | 1566 | }
|
1540 | 1567 | return convertPadTensorOp(b, padTensorOp, bvm);
|
1541 | 1568 | })
|
1542 |
| - .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) { |
1543 |
| - if (failed(getOrAllocateResultBuffers(b, linalgOp.getOperation(), bvm, |
1544 |
| - plan, allocationFn))) { |
| 1569 | + .Case<linalg::LinalgOp, linalg_ext::LinalgExtOp>([&](auto op) { |
| 1570 | + if (failed( |
| 1571 | + getOrAllocateResultBuffers(b, op, bvm, plan, allocationFn))) { |
1545 | 1572 | return failure();
|
1546 | 1573 | }
|
1547 |
| - return convertAnyLinalgOp(b, linalgOp, bvm, plan, allocationFn); |
| 1574 | + return convertAnyLinalgOp(b, op, bvm, plan, allocationFn); |
1548 | 1575 | })
|
1549 | 1576 | .Case<tensor::InsertSliceOp>(
|
1550 | 1577 | [&](tensor::InsertSliceOp subTensorInsertOp) {
|
|
0 commit comments