# diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
# index 63586a5bb8bbb..a5a659abbbb9f 100644
# @@ -216,6 +216,67 @@ def contract(


# Extend and shadow the TableGen-derived version to make sure correct default
# indexing_maps are derived (as there is no mechanism for doing so given the
# Python API bypasses the C++-builders).
class ElementwiseOp_(ElementwiseOp):
    """`ElementwiseOp` that derives identity `indexing_maps` when none are given.

    When `indexing_maps` is supplied it is forwarded unchanged to the base
    class together with the other arguments. When it is `None`, every input
    must have exactly the same type as the first output, and one identity
    affine map of the output's rank is used for each input plus the output.
    """

    def __init__(
        self,
        result_tensors,
        inputs,
        outputs,
        kind,
        *,
        indexing_maps=None,
        loc=None,
        ip=None,
    ):
        """Build the op, defaulting `indexing_maps` to identity maps.

        Args:
            result_tensors: Result types, forwarded to the base class.
            inputs: Input operands (values, or ops/op views resolved through
                `_get_op_result_or_value` when maps are inferred).
            outputs: Output operands; only `outputs[0]` is inspected when
                maps are inferred.
            kind: The elementwise kind, forwarded to the base class.
            indexing_maps: Explicit indexing maps, or `None` to infer
                identity maps.
            loc: Optional location, forwarded to the base class.
            ip: Optional insertion point, forwarded to the base class.

        Raises:
            ValueError: If `indexing_maps` is `None` and `inputs` or `outputs`
                is empty, or some input's type differs from the output's type.
        """
        if indexing_maps is None:
            inputs = [_get_op_result_or_value(in_) for in_ in inputs]
            # Explicit checks instead of `assert`: assertions are stripped
            # under `python -O` and would let mismatched operands through.
            if not inputs:
                raise ValueError(
                    "at least one input is required to infer indexing_maps."
                )
            if len(outputs) == 0:
                raise ValueError(
                    "at least one output is required to infer indexing_maps."
                )
            output = _get_op_result_or_value(outputs[0])
            for position, operand in enumerate(inputs):
                if operand.type != output.type:
                    raise ValueError(
                        f"cannot infer identity indexing_maps: type of input "
                        f"#{position} ({operand.type}) differs from the output "
                        f"type ({output.type}); pass indexing_maps explicitly."
                    )
            # One map per input plus one for the (single) output.
            num_args = len(inputs) + 1
            indexing_maps = [AffineMap.get_identity(output.type.rank)] * num_args

        super().__init__(
            result_tensors=result_tensors,
            inputs=inputs,
            outputs=outputs,
            kind=kind,
            indexing_maps=indexing_maps,
            loc=loc,
            ip=ip,
        )


ElementwiseOp = ElementwiseOp_


def elementwise(
    *ins: Union[Operation, OpView, Value],
    outs: Sequence[Union[Operation, OpView, Value]],
    kind: Union[ElementwiseKind, Attribute],
    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
):
    """Create an `ElementwiseOp` over `ins` writing into the single `outs` entry.

    Args:
        *ins: Input operands; each is resolved with `_get_op_result_or_value`.
        outs: Sequence holding exactly one output operand.
        kind: The elementwise kind to apply.
        indexing_maps: Optional explicit indexing maps; when `None` the op
            constructor is left to derive them.

    Returns:
        Whatever `_get_op_result_or_op_results` yields for the created op.
        The op is given one result type (the output's type) when the output
        is a `RankedTensorType`, and no result types otherwise.

    Raises:
        ValueError: If `outs` does not contain exactly one element.
    """
    if len(outs) != 1:
        raise ValueError(f"{outs=} must have length 1.")
    operands = [_get_op_result_or_value(operand) for operand in ins]
    init = _get_op_result_or_value(outs[0])
    # Only a ranked-tensor output gives the op a result type.
    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []

    op = ElementwiseOp(
        result_tensors=result_types,
        inputs=operands,
        outputs=[init],
        kind=kind,
        indexing_maps=indexing_maps,
    )
    # NOTE(review): presumably populates the op's region body for `kind` --
    # confirm against `fill_builtin_region`.
    fill_builtin_region(op.operation)
    return _get_op_result_or_op_results(op)


# (unchanged diff context follows in the original: `def pack( source, dest,`)
# diff --git a/mlir/test/python/dialects/linalg/ops.py
b/mlir/test/python/dialects/linalg/ops.py index e32a911b24b11..5a163474210a6 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -606,3 +606,189 @@ def tensor_pack(src, dst): # CHECK: return %[[VAL_4]] : tensor<128x128xf32> # CHECK: } print(module) + + +# CHECK-LABEL: TEST: testElementwiseOp +@run +def testElementwiseOp(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + rect_shape = (8, 16) + vert_line_shape = (8,) + hor_line_shape = (16,) + transposed_rect_shape = (16, 8) + + # CHECK-DAG: #[[$IdentMap2D:.*]] = affine_map<(d0, d1) -> (d0, d1)> + # CHECK-DAG: #[[$TransMap2D:.*]] = affine_map<(d0, d1) -> (d1, d0)> + # CHECK-DAG: #[[$VertLineBCastMap:.*]] = affine_map<(d0, d1) -> (d0)> + # CHECK-DAG: #[[$HorLineBCastMap:.*]] = affine_map<(d0, d1) -> (d1)> + + ident_map_2d = AffineMap.get_identity(2) + transposed_map_2d = AffineMap.get_permutation((1, 0)) + vert_line_bcast_map = AffineMap.get(2, 0, [AffineDimExpr.get(0)]) + hor_line_bcast_map = AffineMap.get(2, 0, [AffineDimExpr.get(1)]) + + # CHECK: func.func @elementwise_op( + @func.FuncOp.from_py_func( + # CHECK-SAME: %[[Rect:.*]]: tensor<8x16xf32>, + RankedTensorType.get(rect_shape, f32), + # CHECK-SAME: %[[RectMem:.*]]: memref<8x16xf32>, + MemRefType.get(rect_shape, f32), + # CHECK-SAME: %[[VertLine:.*]]: tensor<8xf32>, + RankedTensorType.get(vert_line_shape, f32), + # CHECK-SAME: %[[VertLineMem:.*]]: memref<8xf32>, + MemRefType.get(vert_line_shape, f32), + # CHECK-SAME: %[[HorLine:.*]]: tensor<16xf32>, + RankedTensorType.get(hor_line_shape, f32), + # CHECK-SAME: %[[HorLineMem:.*]]: memref<16xf32>, + MemRefType.get(hor_line_shape, f32), + # CHECK-SAME: %[[TransRect:.*]]: tensor<16x8xf32>, + RankedTensorType.get(transposed_rect_shape, f32), + # CHECK-SAME: %[[TransRectMem:.*]]: memref<16x8xf32>) + MemRefType.get(transposed_rect_shape, f32), + ) + def elementwise_op( + rect, + rect_mem, + 
vert_line, + vert_line_mem, + hor_line, + hor_line_mem, + trans_rect, + trans_rect_mem, + ): + # CHECK: %[[OutRect:.*]] = tensor.empty() : tensor<8x16xf32> + out_rect = tensor.EmptyOp(rect_shape, f32) + # CHECK: %[[OutRectMem:.*]] = memref.alloca() : memref<8x16xf32> + out_rect_mem = memref.alloca(MemRefType.get(rect_shape, f32), [], []) + + if _inferred_affine_maps := True: + # CHECK: linalg.elementwise + # CHECK-SAME: kind=#linalg.elementwise_kind + # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>) + # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32> + op1 = linalg.ElementwiseOp( + result_tensors=(out_rect.result.type,), + inputs=(rect,), + outputs=(out_rect,), + kind=linalg.ElementwiseKind.exp, + ) + linalg.fill_builtin_region(op1.operation) + + # CHECK: linalg.elementwise + # CHECK-SAME: kind=#linalg.elementwise_kind + # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>) + # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32> + linalg.elementwise( + rect, + outs=(out_rect,), + kind=linalg.ElementwiseKind.exp, + ) + + # CHECK: linalg.elementwise + # CHECK-SAME: kind=#linalg.elementwise_kind + # CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>) + # CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>) + linalg.elementwise( + rect_mem, + outs=(out_rect_mem,), + kind=linalg.ElementwiseKind.exp, + ) + + if _explicit_ident_affine_maps := True: + # Same as above but with default identity indexing_maps explicitly provided. 
+ # CHECK: linalg.elementwise + # CHECK-SAME: kind=#linalg.elementwise_kind + # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>) + # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32> + op3 = linalg.ElementwiseOp( + result_tensors=(out_rect.result.type,), + inputs=(rect,), + outputs=(out_rect,), + kind=linalg.ElementwiseKind.exp, + indexing_maps=[ident_map_2d, ident_map_2d], + ) + linalg.fill_builtin_region(op3.operation) + + # CHECK: linalg.elementwise + # CHECK-SAME: kind=#linalg.elementwise_kind + # CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>) + # CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>) + linalg.elementwise( + rect_mem, + outs=(out_rect_mem,), + kind=linalg.ElementwiseKind.exp, + indexing_maps=[ident_map_2d, ident_map_2d], + ) + + if _ops_with_non_ident_input_maps := True: + # CHECK: linalg.elementwise kind=#linalg.elementwise_kind + # CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$IdentMap2D]]] + # CHECK-SAME: ins(%[[VertLine]] : tensor<8xf32>) + # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32> + op4 = linalg.ElementwiseOp( + result_tensors=(out_rect.result.type,), + inputs=(vert_line,), + outputs=(out_rect,), + kind=linalg.ElementwiseKind.exp, + indexing_maps=[vert_line_bcast_map, ident_map_2d], + ) + linalg.fill_builtin_region(op4.operation) + + # CHECK: linalg.elementwise kind=#linalg.elementwise_kind + # CHECK-SAME: indexing_maps = [#[[$IdentMap2D]], #[[$VertLineBCastMap]], #[[$IdentMap2D]]] + # CHECK-SAME: ins(%[[Rect]], %[[VertLine]] : tensor<8x16xf32>, tensor<8xf32>) + # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32> + op4 = linalg.ElementwiseOp( + result_tensors=(out_rect.result.type,), + inputs=(rect, vert_line), + outputs=(out_rect,), + kind=linalg.ElementwiseKind.add, + indexing_maps=[ident_map_2d, vert_line_bcast_map, ident_map_2d], + ) + linalg.fill_builtin_region(op4.operation) + + # CHECK: linalg.elementwise kind=#linalg.elementwise_kind
+ # CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$IdentMap2D]]] + # CHECK-SAME: ins(%[[VertLine]], %[[HorLine]] : tensor<8xf32>, tensor<16xf32>) + # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32> + linalg.elementwise( + vert_line, + hor_line, + outs=(out_rect,), + kind=linalg.ElementwiseKind.div, + indexing_maps=[ + vert_line_bcast_map, + hor_line_bcast_map, + ident_map_2d, + ], + ) + + if _ops_with_non_ident_and_transposed_input_maps := True: + # CHECK: %[[VertLineBoolsMem:.*]] = memref.alloca() : memref<8xi1> + vert_line_bools_mem = memref.alloca( + MemRefType.get(vert_line_shape, IntegerType.get_signless(1)), + [], + [], + ) + # CHECK: linalg.elementwise kind=#linalg.elementwise_kind