Commit 5eae05a

Merge branch 'google' into main-to-google
2 parents 75e4f33 + 0f26203 commit 5eae05a

31 files changed: +149 -139 lines changed

SUBMODULE_VERSIONS.txt

Lines changed: 2 additions & 2 deletions
@@ -4,14 +4,14 @@
 aa533abfd4232b01f9e57041d70114d5a77e6de0 third_party/googletest
 88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing
 acd6f6f014c25e46363e718381e0b35205df2d83 third_party/libyaml
-5b8ddd2ccceb8de04bd020f286bc3ca38638ecb1 third_party/llvm-project
+ce211c505b82e5bbb68b936968d9b54608285416 third_party/llvm-project
 8a46b64b269c6c8da865ccf25a4a221d6ae28fdc third_party/mlir-emitc
 a41d23745eb902d7093ba6eaf4902c6ec8bf12b2 third_party/mlir-hlo
 4c7697dbe973ed01ae6fbec37d186ebd05982e1f third_party/pybind11
 2e1b5fb39ebc2ef4cb77005f8267e4f3a6241ba1 third_party/spirv_cross
 f5417a4b6633c3217c9a1bc2f0c70b1454975ba7 third_party/spirv_headers
 b42009b3b9d4ca35bc703f5310eedc74f584be58 third_party/stblib
-ab3db3801cc548803a01c2e7c6e2e36ff009005c third_party/tensorflow
+c134bdea4de1950c811dfbc365abae834be61ab7 third_party/tensorflow
 f03b677ffa0fd96fcf859c32e79b740fac7dd59e third_party/tracy
 9d10a96f2d57c3c37e167f2e73c9a31ac2e51fa5 third_party/vulkan_headers
 8d4a9e9174a9c6ad6a3a3ae981b915ef13fc12c4 third_party/vulkan_memory_allocator
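
Note: the llvm-project and tensorflow submodule bumps track the upstream MLIR change that split `dim` on tensor values out of the memref dialect into a dedicated tensor.dim op; the remaining files in this commit update IREE's passes and tests to match. A minimal sketch of the resulting convention, assuming an OpBuilder `b`, a Location `loc`, and a ranked value `v` (the helper name getDimOfValue is invented for illustration and is not part of this patch):

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

// Dims of tensor SSA values are now built as tensor.dim; memref.dim stays
// reserved for memref values.
static mlir::Value getDimOfValue(mlir::OpBuilder &b, mlir::Location loc,
                                 mlir::Value v, int64_t idx) {
  if (v.getType().isa<mlir::RankedTensorType>())
    return b.createOrFold<mlir::tensor::DimOp>(loc, v, idx);
  return b.createOrFold<mlir::memref::DimOp>(loc, v, idx);
}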

integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ class ConvertToMHLOPass : public PassWrapper<ConvertToMHLOPass, FunctionPass> {
     target.addLegalDialect<tensor::TensorDialect>();
     target.addLegalOp<mlir::CallOp>();
     target.addLegalOp<mlir::tensor::CastOp>();
-    target.addLegalOp<mlir::memref::DimOp>();
+    target.addLegalOp<mlir::tensor::DimOp>();
 
     // TODO(suderman): Enable logicistic op for lowering once the op is
     // supported in IREE. Also, remove the numerically unstable ConvertSigmoidOp

iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir

Lines changed: 2 additions & 2 deletions
@@ -12,9 +12,9 @@
 // CHECK-NEXT: %[[ARG1_DIM0:.+]] = hal.buffer_view.dim %[[ARG1]], 0 : index
 // CHECK-NEXT: %[[ARG1_TENSOR:.+]] = hal.tensor.cast %[[ARG1]] : !hal.buffer_view -> tensor<?x8x8x3xf32>{%[[ARG1_DIM0]]}
 // CHECK-NEXT: %[[RET_TENSOR:.+]]:2 = call @_dynamicEntry(%[[ARG0_TENSOR]], %[[ARG1_TENSOR]])
-// CHECK: %[[RET0_DIM0:.+]] = memref.dim %[[RET_TENSOR]]#0, %c0{{.*}} : tensor<?x8x8x3xf32>
+// CHECK: %[[RET0_DIM0:.+]] = tensor.dim %[[RET_TENSOR]]#0, %c0{{.*}} : tensor<?x8x8x3xf32>
 // CHECK-NEXT: %[[RET0_VIEW:.+]] = hal.tensor.cast %[[RET_TENSOR]]#0 : tensor<?x8x8x3xf32>{%[[RET0_DIM0]]} -> !hal.buffer_view
-// CHECK: %[[RET1_DIM0:.+]] = memref.dim %[[RET_TENSOR]]#1, %c0{{.*}} : tensor<?x8x8x3xf32>
+// CHECK: %[[RET1_DIM0:.+]] = tensor.dim %[[RET_TENSOR]]#1, %c0{{.*}} : tensor<?x8x8x3xf32>
 // CHECK-NEXT: %[[RET1_VIEW:.+]] = hal.tensor.cast %[[RET_TENSOR]]#1 : tensor<?x8x8x3xf32>{%[[RET1_DIM0]]} -> !hal.buffer_view
 // CHECK-NEXT: return %[[RET0_VIEW]], %[[RET1_VIEW]] : !hal.buffer_view, !hal.buffer_view
 // CHECK-NEXT: }

iree/compiler/Codegen/Common/LinalgBufferizePass.cpp

Lines changed: 12 additions & 3 deletions
@@ -427,7 +427,7 @@ static LogicalResult analyseLinalgOps(linalg::LinalgOp linalgOp,
 static bool hasSingleRealUse(Value value) {
   int numUsers = 0;
   for (OpOperand &use : value.getUses()) {
-    if (!isa<memref::DimOp>(use.getOwner())) {
+    if (!isa<memref::DimOp, tensor::DimOp>(use.getOwner())) {
       numUsers++;
     }
   }
@@ -549,7 +549,8 @@ static LogicalResult hasDestructiveUpdateLoopPattern(scf::ForOp forOp,
             [&](tensor::InsertSliceOp subTensorInsertOp) {
               return subTensorInsertOp.dest() == arg;
             })
-        .Case<memref::DimOp, scf::YieldOp>([&](auto op) { return true; })
+        .Case<memref::DimOp, scf::YieldOp, tensor::DimOp>(
+            [&](auto op) { return true; })
         .Default([&](Operation *op) { return false; });
   };
   if (llvm::all_of(arg.getUses(), isDestructiveUpdateUses)) {
@@ -1616,7 +1617,7 @@ void LinalgBufferizePass::runOnOperation() {
         .Case<tensor::ExtractOp>([&](tensor::ExtractOp op) {
           return convertTensorExtractOp(b, op, bvm);
         })
-        .Case<memref::DimOp, vector::TransferReadOp>([&](auto op) {
+        .Case<vector::TransferReadOp>([&](auto op) {
           for (unsigned i : llvm::seq<unsigned>(0, op->getNumOperands())) {
             Value operand = op->getOperand(i);
             if (operand.getType().isa<RankedTensorType>()) {
@@ -1626,6 +1627,14 @@
           }
           return success();
         })
+        .Case<tensor::DimOp>([&](tensor::DimOp dimOp) {
+          Value operand = dimOp.source();
+          Value remappedVal = bvm.lookupOrNull(operand);
+          Value newDimOp = b.create<memref::DimOp>(dimOp.getLoc(), remappedVal,
+                                                   dimOp.index());
+          dimOp.replaceAllUsesWith(newDimOp);
+          return success();
+        })
         .Case<scf::ForOp>([&](scf::ForOp forOp) {
           // To canonicalize the `scf.for` tensor result/operand/yield value
           // away, forward the init argument to the yeild of the loop.
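
The new tensor::DimOp case is the substantive addition here: tensor.dim produces an index rather than a tensor, so the op needs no buffer of its own; its tensor operand is looked up in the bufferization value map and a memref.dim of the resulting buffer replaces all uses. A standalone sketch of the same replacement, assuming a BlockAndValueMapping `bvm` from tensors to their buffers (the helper name bufferizeTensorDimOp is invented and not part of the pass):

// Look up the buffer assigned to the tensor operand and re-issue the
// dimension query against it.
static LogicalResult bufferizeTensorDimOp(OpBuilder &b, tensor::DimOp dimOp,
                                          BlockAndValueMapping &bvm) {
  Value buffer = bvm.lookupOrNull(dimOp.source());
  if (!buffer) return failure();  // operand has no buffer (yet)
  Value newDim =
      b.create<memref::DimOp>(dimOp.getLoc(), buffer, dimOp.index());
  dimOp.replaceAllUsesWith(newDim);
  return success();
}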

iree/compiler/Codegen/Common/test/linalg_bufferize.mlir

Lines changed: 15 additions & 15 deletions
@@ -970,8 +970,8 @@ func @subtensor_insert() {
   %2 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xi32>
   %3 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xi32> -> tensor<?x?xi32>
   %4 = flow.dispatch.tensor.load %1, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xi32> -> tensor<?x?xi32>
-  %5 = memref.dim %3, %c0 : tensor<?x?xi32>
-  %6 = memref.dim %3, %c1 : tensor<?x?xi32>
+  %5 = tensor.dim %3, %c0 : tensor<?x?xi32>
+  %6 = tensor.dim %3, %c1 : tensor<?x?xi32>
   %7 = tensor.insert_slice %3 into %4[3, 4] [%5, %6] [1, 1] : tensor<?x?xi32> into tensor<?x?xi32>
   flow.dispatch.tensor.store %7, %2, offsets = [], sizes = [], strides = [] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>
   return
@@ -1118,8 +1118,8 @@ func @gather() {
   %2 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xf32>
   %4 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = []: !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
   %5 = flow.dispatch.tensor.load %1, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?xi32> -> tensor<?xi32>
-  %d0 = memref.dim %5, %c0 : tensor<?xi32>
-  %d1 = memref.dim %4, %c1 : tensor<?x?xf32>
+  %d0 = tensor.dim %5, %c0 : tensor<?xi32>
+  %d1 = tensor.dim %4, %c1 : tensor<?x?xf32>
   %3 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
   %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<?xi32>) outs(%3 : tensor<?x?xf32>) {
   ^bb0( %arg2: i32, %arg3: f32): // no predecessors
@@ -1203,8 +1203,8 @@ func @read_only_subtensor() {
   %workgroup_count_y = hal.interface.workgroup.count[1] : index
   %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
   %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
-  %dim0 = memref.dim %2, %c0 : tensor<?x?xf32>
-  %dim1 = memref.dim %2, %c1 : tensor<?x?xf32>
+  %dim0 = tensor.dim %2, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %2, %c1 : tensor<?x?xf32>
   scf.for %arg0 = %5 to %dim0 step %6 {
     %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
     %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
@@ -1263,7 +1263,7 @@ func @reshape_read_only() {
   %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
   %3 = linalg.tensor_collapse_shape %2 [[0, 1]]
       : tensor<?x?xf32> into tensor<?xf32>
-  %4 = memref.dim %3, %c0 : tensor<?xf32>
+  %4 = tensor.dim %3, %c0 : tensor<?xf32>
   %5 = linalg.init_tensor [%4] : tensor<?xf32>
   %6 = linalg.generic {
     indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
@@ -1496,8 +1496,8 @@ func @rank_reduced_subtensor_insert() {
   %1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<readwrite:?x?x?xf32>
   %2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
   %3 = flow.dispatch.tensor.load %1, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readwrite:?x?x?xf32> -> tensor<?x?x?xf32>
-  %4 = memref.dim %3, %c1 : tensor<?x?x?xf32>
-  %5 = memref.dim %3, %c2 : tensor<?x?x?xf32>
+  %4 = tensor.dim %3, %c1 : tensor<?x?x?xf32>
+  %5 = tensor.dim %3, %c2 : tensor<?x?x?xf32>
   %6 = tensor.insert_slice %2 into %3[0, 0, 0] [1, %4, %5] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
   flow.dispatch.tensor.store %6, %1, offsets = [], sizes = [], strides = [] : tensor<?x?x?xf32> -> !flow.dispatch.tensor<readwrite:?x?x?xf32>
   return
@@ -2035,9 +2035,9 @@ module {
         %18 = tensor.extract_slice %8[%arg2, %arg6] [%16, %17] [1, 1] : tensor<?x144xf32> to tensor<?x?xf32>
         %19 = affine.min #map5(%arg4)
         %20 = tensor.extract_slice %10[%arg6, %arg4] [%17, %19] [1, 1] : tensor<144x?xf32> to tensor<?x?xf32>
-        %21 = memref.dim %arg7, %c0 : tensor<?x?xf32>
+        %21 = tensor.dim %arg7, %c0 : tensor<?x?xf32>
         %22 = affine.min #map6(%21, %arg2)
-        %23 = memref.dim %arg7, %c1 : tensor<?x?xf32>
+        %23 = tensor.dim %arg7, %c1 : tensor<?x?xf32>
         %24 = affine.min #map6(%23, %arg4)
         %25 = tensor.extract_slice %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
        %26 = linalg.matmul {__internal_linalg_transform__ = "workgroup_l1_tile", lowering.config = #config1} ins(%18, %20 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%25 : tensor<?x?xf32>) -> tensor<?x?xf32>
@@ -2140,9 +2140,9 @@ module {
         %18 = tensor.extract_slice %8[%arg2, %arg6] [%16, %17] [1, 1] : tensor<?x144xf32> to tensor<?x?xf32>
         %19 = affine.min #map5(%arg4)
         %20 = tensor.extract_slice %10[%arg6, %arg4] [%17, %19] [1, 1] : tensor<144x?xf32> to tensor<?x?xf32>
-        %21 = memref.dim %arg7, %c0 : tensor<?x?xf32>
+        %21 = tensor.dim %arg7, %c0 : tensor<?x?xf32>
         %22 = affine.min #map6(%21, %arg2)
-        %23 = memref.dim %arg7, %c1 : tensor<?x?xf32>
+        %23 = tensor.dim %arg7, %c1 : tensor<?x?xf32>
         %24 = affine.min #map6(%23, %arg4)
         %25 = tensor.extract_slice %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
        %26 = linalg.matmul {__internal_linalg_transform__ = "workgroup_l1_tile", lowering.config = #config1} ins(%18, %20 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%25 : tensor<?x?xf32>) -> tensor<?x?xf32>
@@ -2241,9 +2241,9 @@ module {
         %18 = tensor.extract_slice %8[%arg2, %arg6] [%16, %17] [1, 1] : tensor<?x144xf32> to tensor<?x?xf32>
         %19 = affine.min #map5(%arg4)
         %20 = tensor.extract_slice %10[%arg6, %arg4] [%17, %19] [1, 1] : tensor<144x?xf32> to tensor<?x?xf32>
-        %21 = memref.dim %arg7, %c0 : tensor<?x?xf32>
+        %21 = tensor.dim %arg7, %c0 : tensor<?x?xf32>
         %22 = affine.min #map6(%21, %arg2)
-        %23 = memref.dim %arg7, %c1 : tensor<?x?xf32>
+        %23 = tensor.dim %arg7, %c1 : tensor<?x?xf32>
         %24 = affine.min #map6(%23, %arg4)
         %25 = tensor.extract_slice %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
        %26 = linalg.matmul {__internal_linalg_transform__ = "workgroup_l1_tile", lowering.config = #config1} ins(%18, %20 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%25 : tensor<?x?xf32>) -> tensor<?x?xf32>

iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp

Lines changed: 3 additions & 3 deletions
@@ -404,12 +404,12 @@ namespace {
 // shapex.ranked_dim(flow.dispatch.shape(%x), %const)
 // ``
 struct ConvertDimOfDispatchInputLoadToDispatchShape
-    : public OpRewritePattern<memref::DimOp> {
+    : public OpRewritePattern<tensor::DimOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(memref::DimOp op,
+  LogicalResult matchAndRewrite(tensor::DimOp op,
                                 PatternRewriter &rewriter) const override {
-    auto loadOp = op.memrefOrTensor().getDefiningOp<DispatchTensorLoadOp>();
+    auto loadOp = op.source().getDefiningOp<DispatchTensorLoadOp>();
     if (!loadOp) return failure();
 
     Optional<int64_t> constantIndex = op.getConstantIndex();
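
As the comment above the pattern notes, a dim query on the result of a flow.dispatch.tensor.load can be folded into a shape query on the dispatch input (shapex.ranked_dim of flow.dispatch.shape); the pattern now anchors on tensor::DimOp and reads its operand through source() instead of the old memrefOrTensor() accessor. A sketch of how such a pattern is typically exercised in isolation (illustrative only: `context` and `funcOp` are assumed to be in scope, and this is not how the pattern is wired up in this file):

// Collect the pattern into a set and apply it greedily over a function.
RewritePatternSet patterns(context);
patterns.insert<ConvertDimOfDispatchInputLoadToDispatchShape>(context);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));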

iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ func @convertDimOfDispatchInputLoadToDispatchShape(%arg0: !flow.dispatch.tensor<
   // CHECK-NEXT: "test.sink"(%[[DIM]]) : (index) -> ()
   %tensor = flow.dispatch.tensor.load %arg0, offsets=[], sizes=[], strides=[] : !flow.dispatch.tensor<readonly:?xf32> -> tensor<?xf32>
   %c0 = constant 0 : index
-  %dim = memref.dim %tensor, %c0 : tensor<?xf32>
+  %dim = tensor.dim %tensor, %c0 : tensor<?xf32>
   "test.sink"(%dim) : (index) -> ()
   return
 }

iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ static SmallVector<Value, 4> getDynamicDimValues(OpBuilder &b, Location loc,
   SmallVector<Value, 4> dynamicDims;
   for (auto dim : llvm::enumerate(v.getType().cast<ShapedType>().getShape())) {
     if (dim.value() != ShapedType::kDynamicSize) continue;
-    dynamicDims.push_back(b.createOrFold<memref::DimOp>(loc, v, dim.index()));
+    dynamicDims.push_back(b.createOrFold<tensor::DimOp>(loc, v, dim.index()));
   }
   return dynamicDims;
 }
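
getDynamicDimValues emits one index value per dynamic dimension of `v`, now as tensor.dim ops, and skips static extents. A hypothetical call site, just to show the contract (variable names invented):

// For %v of type tensor<?x4x?xf32> this yields two Values,
// tensor.dim %v, 0 and tensor.dim %v, 2; the static extent 4 is skipped.
SmallVector<Value, 4> dims = getDynamicDimValues(b, loc, v);
assert(dims.size() == v.getType().cast<ShapedType>().getNumDynamicDims());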

iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ static bool hasDestructiveUpdateSubTensorUses(
       writes.push_back(subTensorInsertOp);
       continue;
     }
-    if (auto dimOp = dyn_cast<memref::DimOp>(u.getOwner())) {
+    if (auto dimOp = dyn_cast<tensor::DimOp>(u.getOwner())) {
       continue;
     }
     LLVM_DEBUG(llvm::dbgs() << "found non-destructive update pattern use: "

iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp

Lines changed: 7 additions & 8 deletions
@@ -17,7 +17,6 @@
 #include "llvm/Support/CommandLine.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -104,7 +103,7 @@ struct DispatchLinalgOnTensorsPass
   void getDependentDialects(DialectRegistry &registry) const override {
     registry
         .insert<AffineDialect, IREE::Flow::FlowDialect, linalg::LinalgDialect,
-                memref::MemRefDialect, scf::SCFDialect, ShapeDialect>();
+                scf::SCFDialect, ShapeDialect, tensor::TensorDialect>();
   }
   DispatchLinalgOnTensorsPass() = default;
   DispatchLinalgOnTensorsPass(const DispatchLinalgOnTensorsPass &pass) {}
@@ -665,7 +664,7 @@ static LogicalResult legalizeDispatchWorkgroupOperands(
     if (auto rt = operand.getType().dyn_cast<RankedTensorType>()) {
       for (unsigned i = 0; i < rt.getRank(); ++i) {
         if (!rt.isDynamicDim(i)) continue;
-        auto dim = builder.createOrFold<memref::DimOp>(dispatchOp.getLoc(),
+        auto dim = builder.createOrFold<tensor::DimOp>(dispatchOp.getLoc(),
                                                        operand, i);
         operandDynamicDims.push_back(dim);
       }
@@ -715,7 +714,7 @@ static Optional<SmallVector<SmallVector<Value, 4>, 1>> computeOutputShape(
 
 static bool hasOnlyDimUses(Operation *op) {
   return llvm::all_of(op->getUsers(), [&](Operation *user) {
-    return isa<memref::DimOp>(user);
+    return isa<tensor::DimOp>(user);
   });
 }
 
@@ -808,7 +807,7 @@ struct TileAndDistributeOnTensorsPattern
 
     rewriter.replaceOpWithIf(op, dispatchOp.getResults(),
                              [&](OpOperand &operand) {
-                               return !isa<memref::DimOp>(operand.getOwner());
+                               return !isa<tensor::DimOp>(operand.getOwner());
                              });
     return success();
   }
@@ -865,7 +864,7 @@ static Optional<SmallVector<SmallVector<Value>>> getResultShapes(
     SmallVector<Value> shape;
     for (auto dim :
          llvm::seq<int64_t>(0, v.getType().cast<ShapedType>().getRank())) {
-      shape.push_back(rewriter.createOrFold<memref::DimOp>(loc, v, dim));
+      shape.push_back(rewriter.createOrFold<tensor::DimOp>(loc, v, dim));
     }
     return shape;
   };
@@ -898,7 +897,7 @@ struct MakeDispatchWorkgroupsOp : public RewritePattern {
         llvm::all_of(op->getUsers(), [](Operation *user) {
          return isDispatchableOp(user) ||
                 user->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>() ||
-                isa<IREE::Flow::DispatchWorkgroupsOp, memref::DimOp>(user);
+                isa<IREE::Flow::DispatchWorkgroupsOp, tensor::DimOp>(user);
         })) {
       return failure();
     }
@@ -966,7 +965,7 @@ struct MakeDispatchWorkgroupsOp : public RewritePattern {
     rewriter.replaceOpWithIf(op, dispatchOp.getOperation()->getResults(),
                              [&](OpOperand &operand) {
                                Operation *user = operand.getOwner();
-                               return !isa<memref::DimOp>(user);
+                               return !isa<tensor::DimOp>(user);
                              });
     return success();
   }
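
The dependent-dialect swap near the top of this file follows the general MLIR rule that a pass must declare every dialect whose ops it may create, so the context loads them before the pass runs; dispatch region formation now materializes tensor.dim rather than memref.dim, so tensor::TensorDialect replaces memref::MemRefDialect and the MemRef.h include can be dropped. A minimal sketch of the rule with an invented pass (all names hypothetical):

// The tensor dialect is declared because the pass body creates tensor.dim ops.
struct MaterializeDimsPass
    : public PassWrapper<MaterializeDimsPass, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<tensor::TensorDialect>();
  }
  void runOnFunction() override {
    OpBuilder builder(&getContext());
    // ... builder.createOrFold<tensor::DimOp>(loc, someTensorValue, /*idx=*/0) ...
  }
};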
