Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1ce189f

Browse files
Prakalp Srivastavatensorflower-gardener
authored andcommitted
Refine result type of std.tensor_cast op during shape inference.
std.tensor_cast operand and result are cast compatible, so it is safe to refine the result type of std.tensor_cast to be same as operand. PiperOrigin-RevId: 308913809 Change-Id: I414ad8f32d15e864faddff75e5bf2eaa4bb95262
1 parent d89b24f commit 1ce189f

2 files changed

Lines changed: 13 additions & 0 deletions

File tree

tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,15 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
291291
return %0 : tensor<?x?x?xf32>
292292
}
293293

294+
// Tests that tensor_cast result shapes are refined.
295+
// CHECK-LABEL: func @tensor_cast_refine
296+
func @tensor_cast_refine(%arg0: tensor<4xi32>) -> (tensor<*xi32>) {
297+
// CHECK: tensor_cast
298+
// CHECK-SAME: tensor<4xi32> to tensor<4xi32>
299+
%0 = tensor_cast %arg0 : tensor<4xi32> to tensor<*xi32>
300+
return %0 : tensor<*xi32>
301+
}
302+
294303
// CHECK-LABEL: func @fold_cast
295304
func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> {
296305
// CHECK-NOT: Cast

tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ bool InferShapeForNonTFDialectOperation(Operation* op, Dialect* tf_dialect) {
185185
iter_sink.getOperands().drop_front().take_front(), iter_source,
186186
tf_dialect);
187187
}
188+
if (auto tensor_cast = dyn_cast<mlir::TensorCastOp>(op)) {
189+
return InferShapeForPassThroughOps(
190+
tensor_cast.getOperation()->getOperands(), op, tf_dialect);
191+
}
188192
return false;
189193
}
190194

0 commit comments

Comments
 (0)