Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 3964d94

Browse files
authored
[mlir] transform.structured.match fix tilingIface condition (#65337)
The matching condition for payload ops implementing TilingInterface was inverted. Fix it and add a test.
1 parent fa31ce5 commit 3964d94

2 files changed

Lines changed: 19 additions & 1 deletion

File tree

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1145,7 +1145,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
11451145
!isa<LinalgOp>(op))
11461146
return;
11471147
if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1148-
isa<TilingInterface>(op))
1148+
!isa<TilingInterface>(op))
11491149
return;
11501150
if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
11511151
!isa<LoopLikeOpInterface>(op))

mlir/test/Dialect/Linalg/transform-op-match.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,24 @@ transform.sequence failures(propagate) {
3939

4040
// -----
4141

42+
func.func @foo(%a: tensor<4x4xf32>, %b: tensor<4x4xf32>, %c: tensor<4x4xf32>) {
43+
%c0 = arith.constant 0.0 : f32
44+
// expected-remark @below {{tileable}}
45+
%r = linalg.fill ins(%c0 : f32) outs(%c : tensor<4x4xf32>) -> tensor<4x4xf32>
46+
// expected-remark @below {{tileable}}
47+
linalg.matmul ins(%a, %b : tensor<4x4xf32>, tensor<4x4xf32>) outs(%r : tensor<4x4xf32>) -> tensor<4x4xf32>
48+
return
49+
}
50+
51+
transform.sequence failures(propagate) {
52+
^bb0(%arg0: !transform.any_op):
53+
%matched = transform.structured.match interface{TilingInterface} in %arg0 : (!transform.any_op) -> !transform.any_op
54+
transform.test_print_remark_at_operand %matched, "tileable" : !transform.any_op
55+
transform.yield
56+
}
57+
58+
// -----
59+
4260
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
4361
#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
4462
func.func @match_complex_attribute(%arg0: tensor<12x128x32xf32>)

0 commit comments

Comments
 (0)