diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 07a192f7b8606..3597209d7f90c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1175,6 +1175,14 @@ FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
 FailureOr<std::pair<Operation *, Operation *>>
 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp);
 
+/// Same as the above but for Fhwc channel orderings in the filter. In this
+/// case the matrix multiplication is actually a row-wise dot product rather
+/// than a row-column dot product. This avoids transposing the filter matrix,
+/// which a regular matrix multiplication would otherwise require to produce
+/// the correct output dimensions.
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp);
+
 /// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except there is no
 /// reduction among the input channels so each convolution can be a
 /// matrix-vector product and by transposing both input filter so channels are
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9ce780d3d249c..8508507871d0c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3118,6 +3118,9 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
           .Case([&](linalg::Conv2DNhwcHwcfOp op) {
             return rewriteInIm2Col(rewriter, op);
           })
+          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+            return rewriteInIm2Col(rewriter, op);
+          })
           .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
             return rewriteInIm2Col(rewriter, op);
           })
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 275e78aaa73dd..e7629d79494bd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -494,6 +494,141 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
                         reshapedResult.getOperation());
 }
 
+FailureOr<std::pair<Operation *, Operation *>>
+rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
+  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
+  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
+  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  // TODO: Support dilation.
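+  // Until then, bail out on non-unit dilations: getConvolvedIndex below
+  // assumes input_index = output_index * stride + filter_index.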
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  MLIRContext *context = rewriter.getContext();
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+
+  int64_t n = outputShape[0];
+  int64_t oh = outputShape[1];
+  int64_t ow = outputShape[2];
+  int64_t oc = outputShape[3];
+  int64_t fh = filterShape[1];
+  int64_t fw = filterShape[2];
+  int64_t ic = filterShape[3];
+
+  Location loc = convOp.getLoc();
+
+  // Reshape output and filter to the LHS and result of a "row-wise" matrix
+  // multiplication.
+  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
+  auto reshapedFilterType =
+      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
+  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedFilterType, filter, filterReassocIndices);
+
+  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
+  RankedTensorType reshapedOutputType =
+      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
+  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, reshapedOutputType, output, outputReassocIndices);
+
+  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
+  Value colTensor = rewriter.create<tensor::EmptyOp>(
+      loc, colTensorShape, inputType.getElementType());
+
+  // Convert the input to a (BMK) column tensor.
+  auto nloops = colTensorShape.size();
+
+  auto parallel = utils::IteratorType::parallel;
+  auto reduction = utils::IteratorType::reduction;
+  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);
+
+  SmallVector<AffineMap> img2colIndexingMaps = {
+      AffineMap::getMultiDimIdentityMap(nloops, context)};
+
+  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
+      loc, colTensor.getType(),
+      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
+      img2colIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        // Get the iterators, named after the matmul dimensions (batch, m, k).
+        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
+        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
+        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
+
+        // Recover the original iteration indices from the problem/input sizes.
+        SmallVector<Value> mIndices = unrollIndex(
+            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
+        auto ohIndex = mIndices[0];
+        auto owIndex = mIndices[1];
+
+        SmallVector<Value> kIndices = unrollIndex(
+            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
+        auto fhIndex = kIndices[0];
+        auto fwIndex = kIndices[1];
+        auto icIndex = kIndices[2];
+
+        // Extract the input element corresponding to the expanded indices.
+        Value hIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
+                              convOp.getStrides().getValues<int64_t>()[0]);
+        Value wIndex =
+            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
+                              convOp.getStrides().getValues<int64_t>()[1]);
+
+        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
+        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
+        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
+            loc, input, extractionIndices);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
+      });
+
+  // Because we didn't transpose the filters, we don't actually have a batched
+  // matrix multiply. Instead, we have an operation consisting of "row-wise"
+  // dot products.
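+  // Concretely, over iteration dims (b, m, n, k) the maps below compute
+  //   out[b, m, n] += im2col[b, m, k] * filter[n, k],
+  // i.e. each (collapsed) filter row is dotted against each im2col row,
+  // which is why no filter transpose is needed.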
+  AffineExpr bDim, mDim, nDim, kDim;
+  bindDims(context, bDim, mDim, nDim, kDim);
+  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
+  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
+  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
+  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
+                                                       parallel, reduction};
+
+  auto genericOp = rewriter.create<linalg::GenericOp>(
+      loc, reshapedOutputType,
+      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
+      /*outputs=*/ValueRange{reshapedOutput},
+      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value mul =
+            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
+        Value add = createAdd(loc, mul, args[2], nestedBuilder);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
+      });
+  Value result = genericOp.getResults().front();
+
+  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputType, result, outputReassocIndices);
+
+  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
+
+  return std::make_pair(img2ColTensor.getOperation(),
+                        reshapedResult.getOperation());
+}
+
 namespace {
 
 class ConvertConv2DNhwcHwcf final
@@ -534,12 +669,25 @@ class ConvertConv2DNchwFchw final
     return success();
   }
 };
+
+class ConvertConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(rewriteInIm2Col(rewriter, convOp)))
+      return failure();
+    return success();
+  }
+};
 } // end anonymous namespace
 
 void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
-                  ConvertConv2DNchwFchw>(context);
+  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
+                  ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(context);
 }
 } // end namespace linalg
 } // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index 657cf83f25460..b2470ed7b7480 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -279,6 +279,76 @@ transform.sequence failures(propagate) {
 
 // -----
 
+// CHECK: IR printer: tensor_producer
+// CHECK-NEXT: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
+
+// Collapsed indices.
+// CHECK: %[[BINDEX:.+]] = linalg.index 0 : index
+// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
+// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
+
+// Compute input channel/convolved indices.
+// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<(d0) -> (d0 mod 4)>(%[[KINDEX]])
+// CHECK: %[[CONVH:.+]] = affine.apply affine_map<(d0, d1) -> (d0 floordiv 14 + d1 floordiv 12)>(%[[MINDEX]], %[[KINDEX]])
+// CHECK: %[[CONVW:.+]] = affine.apply affine_map<(d0, d1) -> (d0 mod 14 + (d1 mod 12) floordiv 4)>(%[[MINDEX]], %[[KINDEX]])
+
+// Extract from the input tensor.
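+// (With this test's unit strides, %[[CONVH]] = oh + fh and %[[CONVW]] = ow + fw,
+// so the extract below reads input[b, oh + fh, ow + fw, ic].)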
+// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
+// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
+// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
+
+// CHECK: IR printer: transformed
+// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: @conv_2d_nhwc_fhwc
+// CHECK-SAME: %[[INPUT:.+]]: tensor<1x16x16x4xf32>
+// CHECK-SAME: %[[FILTER:.+]]: tensor<16x3x3x4xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: tensor<1x14x14x16xf32>
+// CHECK-DAG: %[[COLLAPSED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<16x3x3x4xf32> into tensor<16x36xf32>
+// CHECK-DAG: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0], [1, 2], [3]] : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
+// CHECK: %[[INIT_COL_TENSOR:.+]] = tensor.empty() : tensor<1x196x36xf32>
+// CHECK: %[[COL_TENSOR:.+]] = linalg.generic
+// CHECK-SAME: #[[MAP0]]
+// CHECK: ^bb0(%[[OUT_DATA:.+]]: f32)
+// CHECK: linalg.yield %{{.+}} : f32
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic
+// CHECK-SAME: #[[MAP1]]
+// CHECK-SAME: #[[MAP2]]
+// CHECK-SAME: #[[MAP3]]
+// CHECK-SAME: ins(%[[COL_TENSOR]], %[[COLLAPSED_FILTER]] : tensor<1x196x36xf32>, tensor<16x36xf32>)
+// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xf32>)
+// CHECK: ^bb0(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
+// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+// CHECK: } -> tensor<1x196x16xf32>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK: return %[[RESULT]]
+
+func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
+  transform.print %transformed {name = "transformed"}: !transform.any_op
+}
+
+// -----
+
 // Check for signed extend when the input type is smaller than the accumulator type.
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
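For reference, below is a hand-written sketch of the decomposed IR this rewrite produces for the @conv_2d_nhwc_fhwc test above, consolidating what the CHECK lines verify piecewise. The function name @conv_2d_nhwc_fhwc_decomposed and the SSA value names are illustrative only; the indexing maps, shapes, and region bodies follow the CHECK lines.

func.func @conv_2d_nhwc_fhwc_decomposed(%input: tensor<1x16x16x4xf32>,
    %filter: tensor<16x3x3x4xf32>, %out: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
  // Flatten the filter to (OC, FH*FW*IC) and the output to (N, OH*OW, OC).
  %collapsed_filter = tensor.collapse_shape %filter [[0], [1, 2, 3]]
      : tensor<16x3x3x4xf32> into tensor<16x36xf32>
  %collapsed_out = tensor.collapse_shape %out [[0], [1, 2], [3]]
      : tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
  %empty = tensor.empty() : tensor<1x196x36xf32>
  // im2col gather: purely parallel loops over (b, m, k) that read
  // input[b, oh + fh, ow + fw, ic] (unit strides and dilations).
  %im2col = linalg.generic
      {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel"]}
      outs(%empty : tensor<1x196x36xf32>) {
  ^bb0(%o: f32):
    %b = linalg.index 0 : index
    %m = linalg.index 1 : index
    %k = linalg.index 2 : index
    %ic = affine.apply affine_map<(d0) -> (d0 mod 4)>(%k)
    %h = affine.apply affine_map<(d0, d1) -> (d0 floordiv 14 + d1 floordiv 12)>(%m, %k)
    %w = affine.apply affine_map<(d0, d1) -> (d0 mod 14 + (d1 mod 12) floordiv 4)>(%m, %k)
    %elem = tensor.extract %input[%b, %h, %w, %ic] : tensor<1x16x16x4xf32>
    linalg.yield %elem : f32
  } -> tensor<1x196x36xf32>
  // Row-wise contraction: out[b, m, n] += im2col[b, m, k] * filter[n, k].
  %mm = linalg.generic
      {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%im2col, %collapsed_filter : tensor<1x196x36xf32>, tensor<16x36xf32>)
      outs(%collapsed_out : tensor<1x196x16xf32>) {
  ^bb0(%l: f32, %r: f32, %acc: f32):
    %mul = arith.mulf %l, %r : f32
    %add = arith.addf %mul, %acc : f32
    linalg.yield %add : f32
  } -> tensor<1x196x16xf32>
  // Restore the (N, OH, OW, OC) shape of the result.
  %result = tensor.expand_shape %mm [[0], [1, 2], [3]]
      : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
  return %result : tensor<1x14x14x16xf32>
}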