Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit d05d421

Browse files
committed
[mlir] Add partial lowering of shape.cstr_broadcastable.
Because cstr operations allow more instruction reordering than asserts, we only lower cstr_broadcastable to std ops with cstr_require. This ensures that the more drastic lowering to asserts happens only at the user's explicit request. Differential Revision: https://reviews.llvm.org/D89325
1 parent 952ddc9 commit d05d421

File tree

4 files changed

+78
-1
lines changed

4 files changed

+78
-1
lines changed

mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
set(LLVM_TARGET_DEFINITIONS ShapeToStandard.td)
2+
mlir_tablegen(ShapeToStandard.cpp.inc -gen-rewriters)
3+
add_public_tablegen_target(ShapeToStandardIncGen)
4+
15
add_mlir_conversion_library(MLIRShapeToStandard
26
ConvertShapeConstraints.cpp
37
ShapeToStandard.cpp
@@ -7,6 +11,7 @@ add_mlir_conversion_library(MLIRShapeToStandard
711

812
DEPENDS
913
MLIRConversionPassIncGen
14+
ShapeToStandardIncGen
1015

1116
LINK_COMPONENTS
1217
Core

mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,11 @@ class ToExtentTensorOpConversion
566566
};
567567
} // namespace
568568

569+
namespace {
570+
/// Import the Shape Ops to Std Patterns.
571+
#include "ShapeToStandard.cpp.inc"
572+
} // namespace
573+
569574
namespace {
570575
/// Conversion pass.
571576
class ConvertShapeToStandardPass
@@ -580,7 +585,7 @@ void ConvertShapeToStandardPass::runOnOperation() {
580585
MLIRContext &ctx = getContext();
581586
ConversionTarget target(ctx);
582587
target.addLegalDialect<StandardOpsDialect, SCFDialect>();
583-
target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
588+
target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();
584589

585590
// Setup conversion patterns.
586591
OwningRewritePatternList patterns;
@@ -595,6 +600,7 @@ void ConvertShapeToStandardPass::runOnOperation() {
595600
void mlir::populateShapeToStandardConversionPatterns(
596601
OwningRewritePatternList &patterns, MLIRContext *ctx) {
597602
// clang-format off
603+
populateWithGenerated(ctx, patterns);
598604
patterns.insert<
599605
AnyOpConversion,
600606
BinaryOpConversion<AddOp, AddIOp>,
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//==-- ShapeToStandard.td - Shape to Standard Patterns -------*- tablegen -*==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Defines Patterns to lower Shape ops to Std.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_SHAPETOSTANDARD_TD
14+
#define MLIR_CONVERSION_SHAPETOSTANDARD_TD
15+
16+
include "mlir/Dialect/Shape/IR/ShapeOps.td"
17+
18+
def BroadcastableStringAttr : NativeCodeCall<[{
19+
$_builder.getStringAttr("required broadcastable shapes")
20+
}]>;
21+
22+
def : Pat<(Shape_CstrBroadcastableOp $LHS, $RHS),
23+
(Shape_CstrRequireOp
24+
(Shape_IsBroadcastableOp $LHS, $RHS),
25+
(BroadcastableStringAttr))>;
26+
27+
#endif // MLIR_CONVERSION_SHAPETOSTANDARD_TD

mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,42 @@ func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
420420
// CHECK: }
421421
// CHECK: return %[[ALL_RESULT]] : i1
422422
// CHECK: }
423+
424+
// -----
425+
426+
func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
427+
%0 = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
428+
return %0 : !shape.witness
429+
}
430+
431+
// CHECK-LABEL: func @broadcast(
432+
// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
433+
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
434+
// CHECK: %[[C0:.*]] = constant 0 : index
435+
// CHECK: %[[C1:.*]] = constant 1 : index
436+
// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
437+
// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
438+
// CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
439+
// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
440+
// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
441+
// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
442+
// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
443+
// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
444+
// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
445+
// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
446+
// CHECK: %[[TRUE:.*]] = constant true
447+
// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[VAL_16:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
448+
// CHECK: %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[VAL_16]]] : tensor<?xindex>
449+
// CHECK: %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
450+
// CHECK: %[[LHS_EXTENT_INDEX:.*]] = subi %[[VAL_16]], %[[RANK_DIFF]] : index
451+
// CHECK: %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[LHS_EXTENT_INDEX]]] : tensor<?xindex>
452+
// CHECK: %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
453+
// CHECK: %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
454+
// CHECK: %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
455+
// CHECK: %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
456+
// CHECK: %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
457+
// CHECK: scf.yield %[[NEW_ALL_SO_FAR]] : i1
458+
// CHECK: }
459+
// CHECK: %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
460+
// CHECK: return %[[RESULT]] : !shape.witness
461+
// CHECK: }

0 commit comments

Comments
 (0)