Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit dbb42c9

Browse files
[mlir] update memref.cast cast compatible check (#179313)
Updating memref.cast check regarding if input and output are valid for casting. Currently in case of casting between dynamic and static dims with different strides, the return value of the check is not symmetric and depends if casting for dynamic to static or vice versa. Updating the check logic to make this symmetric.
1 parent 6933962 commit dbb42c9

2 files changed

Lines changed: 8 additions & 2 deletions

File tree

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
766766
if (!checkCompatible(aOffset, bOffset))
767767
return false;
768768
for (const auto &[index, aStride] : enumerate(aStrides)) {
769-
if (aT.getDimSize(index) == 1)
769+
if (aT.getDimSize(index) == 1 || bT.getDimSize(index) == 1)
770770
continue;
771771
if (!checkCompatible(aStride, bStrides[index]))
772772
return false;

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ func.func @memref_alloca_scope() {
211211
}
212212

213213
// CHECK-LABEL: func @memref_cast(%arg0
214-
func.func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64x16x4xf32, strided<[64, 4, 1], offset: 0>>) {
214+
func.func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64x16x4xf32, strided<[64, 4, 1], offset: 0>>, %arg3 : memref<4x1x8xf32, strided<[32, 16, 1]>>, %arg4 : memref<4x?x8xf32, strided<[32, 8, 1]>>) {
215215
// CHECK: memref.cast %{{.*}} : memref<4xf32> to memref<?xf32>
216216
%0 = memref.cast %arg0 : memref<4xf32> to memref<?xf32>
217217

@@ -229,6 +229,12 @@ func.func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memr
229229

230230
// CHECK: memref.cast %{{.*}} : memref<*xf32> to memref<4xf32>
231231
%5 = memref.cast %4 : memref<*xf32> to memref<4xf32>
232+
233+
// CHECK: memref.cast %{{.*}} : memref<4x1x8xf32, strided<[32, 16, 1]>> to memref<4x?x8xf32, strided<[32, 8, 1]>>
234+
%6 = memref.cast %arg3 : memref<4x1x8xf32, strided<[32, 16, 1]>> to memref<4x?x8xf32, strided<[32, 8, 1]>>
235+
236+
// CHECK: memref.cast %{{.*}} : memref<4x?x8xf32, strided<[32, 8, 1]>> to memref<4x1x8xf32, strided<[32, 16, 1]>>
237+
%7 = memref.cast %arg4 : memref<4x?x8xf32, strided<[32, 8, 1]>> to memref<4x1x8xf32, strided<[32, 16, 1]>>
232238
return
233239
}
234240

0 commit comments

Comments
 (0)