
Compile aten.var_mean to a single-pass implementation #22039

@rkayaith

Description

Currently torch-mlir expands var_mean into separate mean and variance calculations: https://github.com/llvm/torch-mlir/blob/dee515896cfb7e9738dfb570f4cd75d33f4b1fad/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp#L10027-L10046

For example, in batch norm we have:

input_f32: "f32[128, 384, 24, 48][442368, 1, 18432, 384]" = ...
var_mean = torch.ops.aten.var_mean.correction(input_f32, [0, 2, 3], correction = 0, keepdim = True)
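For reference, a small runnable PyTorch check of the semantics involved (the tiny shape here is just a stand-in for the real 128x384x24x48 input, and the asserts are illustrative, not from the issue): aten.var_mean.correction with correction=0 matches the explicit two-pass mean/variance decomposition that torch-mlir currently emits.

import torch

# Small stand-in for the 128x384x24x48 batch-norm input.
x = torch.randn(8, 16, 4, 4)

# Fused op: one call returns (variance, mean), reduced over dims [0, 2, 3].
var, mean = torch.ops.aten.var_mean.correction(x, [0, 2, 3], correction=0, keepdim=True)

# Two-pass decomposition (what the decomposition pattern expands to):
mean_ref = x.mean(dim=[0, 2, 3], keepdim=True)                      # pass 1
var_ref = ((x - mean_ref) ** 2).mean(dim=[0, 2, 3], keepdim=True)   # pass 2

torch.testing.assert_close(mean, mean_ref)
torch.testing.assert_close(var, var_ref)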

After fusions etc. in IREE, this looks like:

// -----// IR Dump Before FormDispatchRegionsPass (iree-dispatch-creation-form-dispatch-regions) //----- //
util.func public @fused_op_var_mean_getitem_getitem_1_190c3176aebda633e3fca3fa69e6b943d2e0b7d3_128x384x24x48xfloat32_perm_0231$async(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> (!hal.buffer_view, !hal.buffer_view) attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 1.474560e+05 : f32
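  // 1.474560e+05 = 128 * 24 * 48, the number of elements reduced per channel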
  %0 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<128x24x48x384xf32>
  %1 = tensor.empty() : tensor<384xf32>
  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<384xf32>) -> tensor<384xf32>

  // first pass over input - compute mean
  %3 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0)>],
    iterator_types = ["parallel", "reduction", "reduction", "reduction"]
  } ins(%0 : tensor<128x24x48x384xf32>) outs(%2 : tensor<384xf32>) {
  ^bb0(%in: f32, %out: f32):
    %10 = arith.addf %in, %out : f32
    linalg.yield %10 : f32
  } -> tensor<384xf32>
  %4 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%3 : tensor<384xf32>) outs(%1 : tensor<384xf32>) {
  ^bb0(%in: f32, %out: f32):
    %10 = arith.divf %in, %cst_0 : f32
    linalg.yield %10 : f32
  } -> tensor<384xf32>

  // second pass over input - compute variance
  %5 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d0)>], 
    iterator_types = ["parallel", "reduction", "reduction", "reduction"]
  } ins(%0, %4 : tensor<128x24x48x384xf32>, tensor<384xf32>) outs(%2 : tensor<384xf32>) {
  ^bb0(%in: f32, %in_2: f32, %out: f32):
    %10 = arith.subf %in, %in_2 : f32
    %11 = arith.mulf %10, %10 : f32
    %12 = arith.addf %11, %out : f32
    linalg.yield %12 : f32
  } -> tensor<384xf32>
  %6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%5 : tensor<384xf32>) outs(%1 : tensor<384xf32>) {
  ^bb0(%in: f32, %out: f32):
    %10 = arith.divf %in, %cst_0 : f32
    linalg.yield %10 : f32
  } -> tensor<384xf32>

  %expanded = tensor.expand_shape %6 [[0, 1, 2, 3]] output_shape [1, 1, 1, 384] : tensor<384xf32> into tensor<1x1x1x384xf32>
  %expanded_1 = tensor.expand_shape %4 [[0, 1, 2, 3]] output_shape [1, 1, 1, 384] : tensor<384xf32> into tensor<1x1x1x384xf32>
  %7:2 = hal.tensor.barrier join(%expanded, %expanded_1 : tensor<1x1x1x384xf32>, tensor<1x1x1x384xf32>) => %arg2 : !hal.fence
  %8 = hal.tensor.export %7#0 : tensor<1x1x1x384xf32> -> !hal.buffer_view
  %9 = hal.tensor.export %7#1 : tensor<1x1x1x384xf32> -> !hal.buffer_view
  util.return %8, %9 : !hal.buffer_view, !hal.buffer_view
}

Instead, we should try lowering this to a single-pass implementation, so the large input tensor only needs to be read once. However, there may be issues with vectorization, since this would require a single linalg op performing two reductions.
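For illustration only (the function name and the NumPy check below are not from the issue), here is a minimal Python sketch of Welford's algorithm, which computes both statistics in one traversal of the data. This is the shape of reduction a single-pass lowering would need: one loop body updating multiple accumulators, analogous to a single linalg.generic producing two reduction results. The simpler sum/sum-of-squares formulation also works in one pass but is less numerically stable.

import numpy as np

def welford_var_mean(x, correction: int = 0):
    """Single pass over x producing (variance, mean), matching the
    (var, mean) output order of aten.var_mean."""
    mean = 0.0
    m2 = 0.0      # running sum of squared deviations from the running mean
    count = 0
    for v in np.ravel(x):
        count += 1
        delta = v - mean
        mean += delta / count
        m2 += delta * (v - mean)
    return m2 / (count - correction), mean

# Illustrative check against NumPy (population variance, correction=0).
x = np.random.randn(1024).astype(np.float32)
var, mean = welford_var_mean(x)
assert np.allclose([var, mean], [x.var(), x.mean()], atol=1e-5)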

Implementation references:

Labels: codegen (Shared code generation infrastructure and dialects), codegen/rocm (ROCm code generation compiler backend (HIP/HSA))
