Labels: codegen (Shared code generation infrastructure and dialects), codegen/rocm (ROCm code generation compiler backend (HIP/HSA))
Description
Currently torch-mlir expands var_mean into a separate mean and variance calculation: https://github.com/llvm/torch-mlir/blob/dee515896cfb7e9738dfb570f4cd75d33f4b1fad/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp#L10027-L10046
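This decomposition makes two full passes over the input: one to compute the mean, and a second to accumulate the squared deviations. A minimal NumPy sketch of the same expansion (the function name and signature are illustrative, not the actual torch-mlir code):

```python
import numpy as np

def var_mean_two_pass(x, dims, correction=0, keepdims=True):
    """Two-pass expansion of var_mean: mean first, then variance over dims."""
    axes = tuple(dims)
    n = np.prod([x.shape[d] for d in axes])
    # Pass 1: reduce to the mean.
    mean = x.sum(axis=axes, keepdims=True) / n
    # Pass 2: reduce the squared deviations to the variance.
    var = ((x - mean) ** 2).sum(axis=axes, keepdims=True) / (n - correction)
    if not keepdims:
        mean = mean.squeeze(axis=axes)
        var = var.squeeze(axis=axes)
    return var, mean
```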
For example in batch norm we have:
input_f32: "f32[128, 384, 24, 48][442368, 1, 18432, 384]" = ...
var_mean = torch.ops.aten.var_mean.correction(input_f32, [0, 2, 3], correction = 0, keepdim = True)
After fusions etc. in IREE, this looks like:
// -----// IR Dump Before FormDispatchRegionsPass (iree-dispatch-creation-form-dispatch-regions) //----- //
util.func public @fused_op_var_mean_getitem_getitem_1_190c3176aebda633e3fca3fa69e6b943d2e0b7d3_128x384x24x48xfloat32_perm_0231$async(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> (!hal.buffer_view, !hal.buffer_view) attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 1.474560e+05 : f32
%0 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<128x24x48x384xf32>
%1 = tensor.empty() : tensor<384xf32>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<384xf32>) -> tensor<384xf32>
// first pass over input - compute mean
%3 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0)>],
iterator_types = ["parallel", "reduction", "reduction", "reduction"]
} ins(%0 : tensor<128x24x48x384xf32>) outs(%2 : tensor<384xf32>) {
^bb0(%in: f32, %out: f32):
%10 = arith.addf %in, %out : f32
linalg.yield %10 : f32
} -> tensor<384xf32>
%4 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%3 : tensor<384xf32>) outs(%1 : tensor<384xf32>) {
^bb0(%in: f32, %out: f32):
%10 = arith.divf %in, %cst_0 : f32
linalg.yield %10 : f32
} -> tensor<384xf32>
// second pass over input - compute variance
%5 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d0)>],
iterator_types = ["parallel", "reduction", "reduction", "reduction"]
} ins(%0, %4 : tensor<128x24x48x384xf32>, tensor<384xf32>) outs(%2 : tensor<384xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%10 = arith.subf %in, %in_2 : f32
%11 = arith.mulf %10, %10 : f32
%12 = arith.addf %11, %out : f32
linalg.yield %12 : f32
} -> tensor<384xf32>
%6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%5 : tensor<384xf32>) outs(%1 : tensor<384xf32>) {
^bb0(%in: f32, %out: f32):
%10 = arith.divf %in, %cst_0 : f32
linalg.yield %10 : f32
} -> tensor<384xf32>
%expanded = tensor.expand_shape %6 [[0, 1, 2, 3]] output_shape [1, 1, 1, 384] : tensor<384xf32> into tensor<1x1x1x384xf32>
%expanded_1 = tensor.expand_shape %4 [[0, 1, 2, 3]] output_shape [1, 1, 1, 384] : tensor<384xf32> into tensor<1x1x1x384xf32>
%7:2 = hal.tensor.barrier join(%expanded, %expanded_1 : tensor<1x1x1x384xf32>, tensor<1x1x1x384xf32>) => %arg2 : !hal.fence
%8 = hal.tensor.export %7#0 : tensor<1x1x1x384xf32> -> !hal.buffer_view
%9 = hal.tensor.export %7#1 : tensor<1x1x1x384xf32> -> !hal.buffer_view
util.return %8, %9 : !hal.buffer_view, !hal.buffer_view
}
Instead, we should try lowering this to a single-pass implementation. However, there may be issues with vectorization, since this would require a single linalg op performing two reductions.
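One way to sketch the single-pass version (illustrative NumPy, not the proposed linalg lowering): jointly accumulate sum(x) and sum(x*x) in one traversal, then derive mean and variance as E[x^2] - E[x]^2. NumPy evaluates the two sums as separate reductions here, but a fused linalg.generic with two outputs would compute both in one loop nest. Note that the E[x^2] - E[x]^2 form can lose precision when the mean is large relative to the spread; Welford-style updates are the usual remedy.

```python
import numpy as np

def var_mean_single_pass(x, dims, correction=0, keepdims=True):
    """Sketch of a fused var_mean: one conceptual pass accumulating
    both sum(x) and sum(x*x), i.e. a single reduction producing two results."""
    axes = tuple(dims)
    n = np.prod([x.shape[d] for d in axes])
    s = x.sum(axis=axes, keepdims=keepdims)          # first accumulator
    sq = (x * x).sum(axis=axes, keepdims=keepdims)   # second accumulator
    mean = s / n
    # E[x^2] - E[x]^2, rescaled for the requested correction.
    var = (sq / n - mean * mean) * (n / (n - correction))
    return var, mean
```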
Implementation references: