diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 3b4bb34105581..ae5a68a6be157 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -32,6 +32,7 @@ class FuncOp;
 namespace memref {
 class AllocOp;
 class AllocaOp;
+class ReinterpretCastOp;
 } // namespace memref
 
 namespace affine {
@@ -243,15 +244,16 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                        ArrayRef<Value> symbolOperands = {},
                                        bool allowNonDereferencingOps = false);
 
-/// Rewrites the memref defined by this alloc op to have an identity layout map
-/// and updates all its indexing uses. Returns failure if any of its uses
-/// escape (while leaving the IR in a valid state).
+/// Rewrites the memref defined by an alloc or reinterpret_cast op to have an
+/// identity layout map and updates all its indexing uses. Returns failure if
+/// any of its uses escape (while leaving the IR in a valid state).
 template <typename AllocLikeOp>
 LogicalResult normalizeMemRef(AllocLikeOp op);
 extern template LogicalResult
 normalizeMemRef<memref::AllocaOp>(memref::AllocaOp op);
 extern template LogicalResult
 normalizeMemRef<memref::AllocOp>(memref::AllocOp op);
+LogicalResult normalizeMemRef(memref::ReinterpretCastOp op);
 
 /// Normalizes `memrefType` so that the affine layout map of the memref is
 /// transformed to an identity map with a new shape being computed for the
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 11798b99fa879..ef470c30e680e 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1098,90 +1098,12 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
     op->erase();
 }
 
-// Private helper function to transform memref.load with reduced rank.
-// This function will modify the indices of the memref.load to match the
-// newMemRef.
-LogicalResult transformMemRefLoadWithReducedRank(
-    Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
-    ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
-    ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
-  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
-  unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
-  unsigned oldMapNumInputs = oldMemRefRank;
-  SmallVector<Value, 4> oldMapOperands(
-      op->operand_begin() + memRefOperandPos + 1,
-      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
-  SmallVector<Value, 4> oldMemRefOperands;
-  oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
-  SmallVector<Value, 4> remapOperands;
-  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
-                        symbolOperands.size());
-  remapOperands.append(extraOperands.begin(), extraOperands.end());
-  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
-  remapOperands.append(symbolOperands.begin(), symbolOperands.end());
-
-  SmallVector<Value, 4> remapOutputs;
-  remapOutputs.reserve(oldMemRefRank);
-  SmallVector<Value, 4> affineApplyOps;
-
-  OpBuilder builder(op);
-
-  if (indexRemap &&
-      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
-    // Remapped indices.
-    for (auto resultExpr : indexRemap.getResults()) {
-      auto singleResMap = AffineMap::get(
-          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
-      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
-                                                remapOperands);
-      remapOutputs.push_back(afOp);
-      affineApplyOps.push_back(afOp);
-    }
-  } else {
-    // No remapping specified.
-    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
-  }
-
-  SmallVector<Value, 4> newMapOperands;
-  newMapOperands.reserve(newMemRefRank);
-
-  // Prepend 'extraIndices' in 'newMapOperands'.
-  for (Value extraIndex : extraIndices) {
-    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
-           "invalid memory op index");
-    newMapOperands.push_back(extraIndex);
-  }
-
-  // Append 'remapOutputs' to 'newMapOperands'.
-  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
-
-  // Create new fully composed AffineMap for new op to be created.
-  assert(newMapOperands.size() == newMemRefRank);
-
-  OperationState state(op->getLoc(), op->getName());
-  // Construct the new operation using this memref.
-  state.operands.reserve(newMapOperands.size() + extraIndices.size());
-  state.operands.push_back(newMemRef);
-
-  // Insert the new memref map operands.
-  state.operands.append(newMapOperands.begin(), newMapOperands.end());
-
-  state.types.reserve(op->getNumResults());
-  for (auto result : op->getResults())
-    state.types.push_back(result.getType());
-
-  // Copy over the attributes from the old operation to the new operation.
-  for (auto namedAttr : op->getAttrs()) {
-    state.attributes.push_back(namedAttr);
-  }
-
-  // Create the new operation.
-  auto *repOp = builder.create(state);
-  op->replaceAllUsesWith(repOp);
-  op->erase();
-
-  return success();
+// Checks if `op` is a dereferencing op.
+// TODO: This hardcoded check will be removed once the right interface is added.
+static bool isDereferencingOp(Operation *op) {
+  return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(op);
 }
+
 // Perform the replacement in `op`.
 LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     Value oldMemRef, Value newMemRef, Operation *op,
@@ -1216,53 +1138,53 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   if (usePositions.empty())
     return success();
 
-  if (usePositions.size() > 1) {
-    // TODO: extend it for this case when needed (rare).
-    assert(false && "multiple dereferencing uses in a single op not supported");
-    return failure();
-  }
-
   unsigned memRefOperandPos = usePositions.front();
 
   OpBuilder builder(op);
 
   // The following checks if op is dereferencing memref and performs the access
   // index rewrites.
-  auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
-  if (!affMapAccInterface) {
+  if (!isDereferencingOp(op)) {
     if (!allowNonDereferencingOps) {
       // Failure: memref used in a non-dereferencing context (potentially
       // escapes); no replacement in these cases unless allowNonDereferencingOps
       // is set.
       return failure();
     }
+    for (unsigned pos : usePositions)
+      op->setOperand(pos, newMemRef);
+    return success();
+  }
 
-    // Check if it is a memref.load
-    auto memrefLoad = dyn_cast<memref::LoadOp>(op);
-    bool isReductionLike =
-        indexRemap.getNumResults() < indexRemap.getNumInputs();
-    if (!memrefLoad || !isReductionLike) {
-      op->setOperand(memRefOperandPos, newMemRef);
-      return success();
-    }
+  if (usePositions.size() > 1) {
+    // TODO: extend it for this case when needed (rare).
+    LLVM_DEBUG(llvm::dbgs()
+               << "multiple dereferencing uses in a single op not supported");
+    return failure();
+  }
 
-    return transformMemRefLoadWithReducedRank(
-        op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
-        symbolOperands, indexRemap);
+  // Perform index rewrites for the dereferencing op and then replace the op.
+  SmallVector<Value, 4> oldMapOperands;
+  AffineMap oldMap;
+  unsigned oldMemRefNumIndices = oldMemRefRank;
+  auto startIdx = op->operand_begin() + memRefOperandPos + 1;
+  auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
+  if (affMapAccInterface) {
+    // If `op` implements AffineMapAccessInterface, we can get the indices by
+    // querying the number of map operands from the operand list from a certain
+    // offset (`memRefOperandPos` in this case).
+    NamedAttribute oldMapAttrPair =
+        affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
+    oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
+    oldMemRefNumIndices = oldMap.getNumInputs();
   }
-  // Perform index rewrites for the dereferencing op and then replace the op
-  NamedAttribute oldMapAttrPair =
-      affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
-  AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
-  unsigned oldMapNumInputs = oldMap.getNumInputs();
-  SmallVector<Value, 4> oldMapOperands(
-      op->operand_begin() + memRefOperandPos + 1,
-      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+  oldMapOperands.assign(startIdx, startIdx + oldMemRefNumIndices);
 
   // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
   SmallVector<Value, 4> oldMemRefOperands;
   SmallVector<Value, 4> affineApplyOps;
   oldMemRefOperands.reserve(oldMemRefRank);
-  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
+  if (affMapAccInterface &&
+      oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
     for (auto resultExpr : oldMap.getResults()) {
       auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                          oldMap.getNumSymbols(), resultExpr);
@@ -1287,7 +1209,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
 
   SmallVector<Value, 4> remapOutputs;
   remapOutputs.reserve(oldMemRefRank);
-
   if (indexRemap &&
       indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
     // Remapped indices.
@@ -1303,7 +1224,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     // No remapping specified.
     remapOutputs.assign(remapOperands.begin(), remapOperands.end());
   }
-
   SmallVector<Value, 4> newMapOperands;
   newMapOperands.reserve(newMemRefRank);
 
@@ -1338,13 +1258,26 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   state.operands.push_back(newMemRef);
 
   // Insert the new memref map operands.
-  state.operands.append(newMapOperands.begin(), newMapOperands.end());
+  if (affMapAccInterface) {
+    state.operands.append(newMapOperands.begin(), newMapOperands.end());
+  } else {
+    // In the case of dereferencing ops not implementing
+    // AffineMapAccessInterface, we need to apply the values of `newMapOperands`
+    // to the `newMap` to get the correct indices.
+    for (unsigned i = 0; i < newMemRefRank; i++) {
+      state.operands.push_back(builder.create<AffineApplyOp>(
+          op->getLoc(),
+          AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(),
+                         newMap.getResult(i)),
+          newMapOperands));
+    }
+  }
 
   // Insert the remaining operands unmodified.
+  unsigned oldMapNumInputs = oldMapOperands.size();
   state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                             oldMapNumInputs,
                         op->operand_end());
-
   // Result types don't change. Both memref's are of the same elemental type.
   state.types.reserve(op->getNumResults());
   for (auto result : op->getResults())
@@ -1353,7 +1286,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   // Add attribute for 'newMap', other Attributes do not change.
   auto newMapAttr = AffineMapAttr::get(newMap);
   for (auto namedAttr : op->getAttrs()) {
-    if (namedAttr.getName() == oldMapAttrPair.getName())
+    if (affMapAccInterface &&
+        namedAttr.getName() ==
+            affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef).getName())
       state.attributes.push_back({namedAttr.getName(), newMapAttr});
     else
       state.attributes.push_back(namedAttr);
@@ -1845,6 +1780,94 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) {
   return success();
 }
 
+LogicalResult
+mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) {
+  MemRefType memrefType = reinterpretCastOp.getType();
+  AffineMap oldLayoutMap = memrefType.getLayout().getAffineMap();
+  Value oldMemRef = reinterpretCastOp.getResult();
+
+  // If `oldLayoutMap` is identity, `memrefType` is already normalized.
+  if (oldLayoutMap.isIdentity())
+    return success();
+
+  // Fetch a new memref type after normalizing the old memref to have an
+  // identity map layout.
+  MemRefType newMemRefType = normalizeMemRefType(memrefType);
+  if (newMemRefType == memrefType)
+    // `oldLayoutMap` couldn't be transformed to an identity map.
+    return failure();
+
+  uint64_t newRank = newMemRefType.getRank();
+  SmallVector<Value> mapOperands(oldLayoutMap.getNumDims() +
+                                 oldLayoutMap.getNumSymbols());
+  SmallVector<Value> oldStrides = reinterpretCastOp.getStrides();
+  Location loc = reinterpretCastOp.getLoc();
+  // As `newMemRefType` is normalized, it is unit strided.
+  SmallVector<int64_t> newStaticStrides(newRank, 1);
+  SmallVector<int64_t> newStaticOffsets(newRank, 0);
+  ArrayRef<int64_t> oldShape = memrefType.getShape();
+  ValueRange oldSizes = reinterpretCastOp.getSizes();
+  unsigned idx = 0;
+  SmallVector<int64_t> newStaticSizes;
+  OpBuilder b(reinterpretCastOp);
+  // Collect the map operands which will be used to compute the new normalized
+  // memref shape.
+  for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
+    if (memrefType.isDynamicDim(i))
+      mapOperands[i] =
+          b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
+                                  b.create<arith::ConstantIndexOp>(loc, 1));
+    else
+      mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1);
+  }
+  for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
+    mapOperands[memrefType.getRank() + i] = oldStrides[i];
+  SmallVector<Value> newSizes;
+  ArrayRef<int64_t> newShape = newMemRefType.getShape();
+  // Compute size along all the dimensions of the new normalized memref.
+  for (unsigned i = 0; i < newRank; i++) {
+    if (!newMemRefType.isDynamicDim(i))
+      continue;
+    newSizes.push_back(b.create<AffineApplyOp>(
+        loc,
+        AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(),
+                       oldLayoutMap.getResult(i)),
+        mapOperands));
+  }
+  for (unsigned i = 0, e = newSizes.size(); i < e; i++) {
+    newSizes[i] =
+        b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i],
+                                b.create<arith::ConstantIndexOp>(loc, 1));
+  }
+  // Create the new reinterpret_cast op.
+  auto newReinterpretCast = b.create<memref::ReinterpretCastOp>(
+      loc, newMemRefType, reinterpretCastOp.getSource(),
+      /*offsets=*/ValueRange(), newSizes,
+      /*strides=*/ValueRange(),
+      /*static_offsets=*/newStaticOffsets,
+      /*static_sizes=*/newShape,
+      /*static_strides=*/newStaticStrides);
+
+  // Replace all uses of the old memref.
+  if (failed(replaceAllMemRefUsesWith(oldMemRef,
+                                      /*newMemRef=*/newReinterpretCast,
+                                      /*extraIndices=*/{},
+                                      /*indexRemap=*/oldLayoutMap,
+                                      /*extraOperands=*/{},
+                                      /*symbolOperands=*/oldStrides,
+                                      /*domOpFilter=*/nullptr,
+                                      /*postDomOpFilter=*/nullptr,
+                                      /*allowNonDereferencingOps=*/true))) {
+    // If it failed (due to escapes, for example), bail out.
+    newReinterpretCast.erase();
+    return failure();
+  }
+
+  oldMemRef.replaceAllUsesWith(newReinterpretCast);
+  reinterpretCastOp.erase();
+  return success();
+}
+
 template LogicalResult
 mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp op);
 template LogicalResult
 mlir::affine::normalizeMemRef<memref::AllocOp>(memref::AllocOp op);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index 7abd9d17f5aa1..b408962690810 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -30,6 +30,7 @@ namespace memref {
 
 using namespace mlir;
 using namespace mlir::affine;
+using namespace mlir::memref;
 
 namespace {
 
@@ -159,7 +160,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
     return true;
 
   if (funcOp
-          .walk([&](memref::AllocOp allocOp) -> WalkResult {
+          .walk([&](AllocOp allocOp) -> WalkResult {
             Value oldMemRef = allocOp.getResult();
             if (!allocOp.getType().getLayout().isIdentity() &&
                 !isMemRefNormalizable(oldMemRef.getUsers()))
@@ -170,7 +171,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
     return false;
 
   if (funcOp
-          .walk([&](memref::AllocaOp allocaOp) -> WalkResult {
+          .walk([&](AllocaOp allocaOp) -> WalkResult {
             Value oldMemRef = allocaOp.getResult();
             if (!allocaOp.getType().getLayout().isIdentity() &&
                 !isMemRefNormalizable(oldMemRef.getUsers()))
@@ -341,22 +342,31 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
 }
 
 /// Normalizes the memrefs within a function which includes those arising as a
-/// result of AllocOps, AllocaOps, CallOps and function's argument. The ModuleOp
-/// argument is used to help update function's signature after normalization.
+/// result of AllocOps, AllocaOps, CallOps, ReinterpretCastOps and the
+/// function's arguments. The ModuleOp argument is used to help update the
+/// function's signature after normalization.
 void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
                                               ModuleOp moduleOp) {
   // Turn memrefs' non-identity layouts maps into ones with identity. Collect
-  // alloc/alloca ops first and then process since normalizeMemRef
-  // replaces/erases ops during memref rewriting.
-  SmallVector<memref::AllocOp> allocOps;
-  funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
-  for (memref::AllocOp allocOp : allocOps)
+  // alloc, alloca and reinterpret_cast ops first and then process them, since
+  // normalizeMemRef replaces/erases ops during memref rewriting.
+  SmallVector<AllocOp> allocOps;
+  SmallVector<AllocaOp> allocaOps;
+  SmallVector<ReinterpretCastOp> reinterpretCastOps;
+  funcOp.walk([&](Operation *op) {
+    if (auto allocOp = dyn_cast<AllocOp>(op))
+      allocOps.push_back(allocOp);
+    else if (auto allocaOp = dyn_cast<AllocaOp>(op))
+      allocaOps.push_back(allocaOp);
+    else if (auto reinterpretCastOp = dyn_cast<ReinterpretCastOp>(op))
+      reinterpretCastOps.push_back(reinterpretCastOp);
+  });
+  for (AllocOp allocOp : allocOps)
     (void)normalizeMemRef(allocOp);
-
-  SmallVector<memref::AllocaOp> allocaOps;
-  funcOp.walk([&](memref::AllocaOp op) { allocaOps.push_back(op); });
-  for (memref::AllocaOp allocaOp : allocaOps)
+  for (AllocaOp allocaOp : allocaOps)
     (void)normalizeMemRef(allocaOp);
+  for (ReinterpretCastOp reinterpretCastOp : reinterpretCastOps)
+    (void)normalizeMemRef(reinterpretCastOp);
 
   // We use this OpBuilder to create new memref layout later.
   OpBuilder b(funcOp);
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
index 3bede131325a7..344da4e5e2462 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
@@ -165,3 +165,34 @@ func.func @prefetch_normalize(%arg0: memref<512xf32, affine_map<(d0) -> (d0 floordiv 32, d0 mod 32)>>) -> () {
   }
   return
 }
+
+#map_strided = affine_map<(d0, d1) -> (d0 * 7 + d1)>
+
+// CHECK-LABEL: test_reinterpret_cast
+func.func @test_reinterpret_cast(%arg0: memref<5x7xf32>, %arg1: memref<5x7xf32>, %arg2: memref<5x7xf32>) {
+  %0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [5, 7], strides: [7, 1] : memref<5x7xf32> to memref<5x7xf32, #map_strided>
+  // CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [35], strides: [1] : memref<5x7xf32> to memref<35xf32>
+  affine.for %arg5 = 0 to 5 {
+    affine.for %arg6 = 0 to 7 {
+      %1 = affine.load %0[%arg5, %arg6] : memref<5x7xf32, #map_strided>
+      // CHECK: affine.load %reinterpret_cast[%{{.*}} * 7 + %{{.*}}] : memref<35xf32>
+      %2 = affine.load %arg1[%arg5, %arg6] : memref<5x7xf32>
+      %3 = arith.subf %1, %2 : f32
+      affine.store %3, %arg2[%arg5, %arg6] : memref<5x7xf32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: reinterpret_cast_non_zero_offset
+func.func @reinterpret_cast_non_zero_offset(%arg0: index, %arg1: memref<1x10x17xi32, strided<[?, ?, ?], offset: ?>>, %arg2: memref<1x10x17xi32, strided<[?, ?, ?], offset: ?>>, %arg3: memref<1x10x17xi32, strided<[?, ?, ?], offset: ?>>) -> (memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>) {
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x10x17xi32>
+  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x17xf32>
+  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x10x17xf32>
+  cf.br ^bb3
+^bb3:  // pred: ^bb1
+  // CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [32], strides: [1] : memref<2x17xf32> to memref<32xf32>
+  // CHECK: return %[[REINTERPRET_CAST]], %[[REINTERPRET_CAST]], %{{.*}}, %{{.*}}, %{{.*}} : memref<32xf32>, memref<32xf32>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>
+  %reinterpret_cast = memref.reinterpret_cast %alloc_0 to offset: [27], sizes: [1, 5], strides: [17, 1] : memref<2x17xf32> to memref<1x5xf32, strided<[17, 1], offset: 27>>
+  return %reinterpret_cast, %reinterpret_cast, %alloc_0, %alloc, %alloc_1 : memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<1x5xf32, strided<[17, 1], offset: 27>>, memref<2x17xf32>, memref<1x10x17xi32>, memref<1x10x17xf32>
+}
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index e93a1a4ebae53..440f4776424cc 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -3,8 +3,8 @@
 // This file tests whether the memref type having non-trivial map layouts
 // are normalized to trivial (identity) layouts.
 
-// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> ((d0 mod 2) * 2 + d1 mod 2 + (d0 floordiv 2) * 4 + (d1 floordiv 2) * 8)>
-// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 mod 2 + (d1 mod 2) * 2 + (d0 floordiv 2) * 8 + (d1 floordiv 2) * 4)>
+// CHECK-DAG: #[[$REDUCE_MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 + (d1 floordiv 2) * 6)>
+// CHECK-DAG: #[[$REDUCE_MAP2:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 2 + (d0 floordiv 2) * 6)>
 // CHECK-DAG: #[[$REDUCE_MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
 
 // CHECK-LABEL: func @permute()
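Editor's note, a worked example of the size computation (not part of the patch): normalizeMemRef(memref::ReinterpretCastOp) builds, per source dimension, the last valid index (size - 1 for dynamic dims, shape[i] - 1 for static ones), applies the old layout map to get the linear position of the last element, and adds 1 to obtain each size of the normalized memref. For @test_reinterpret_cast above that is (4, 6) -> 4 * 7 + 6 = 34 -> 34 + 1 = 35, matching the memref<35xf32> in the CHECK lines. The IR below spells this out for a dynamic leading dimension; the function name @size_sketch and the operand %n are hypothetical, not taken from the patch or its tests.

func.func @size_sketch(%n: index) -> index {
  %c1 = arith.constant 1 : index
  %c6 = arith.constant 6 : index
  // Last valid index per dimension of memref<?x7xf32, (d0, d1) -> (d0 * 7 + d1)>: (%n - 1, 6).
  %last0 = arith.subi %n, %c1 : index
  // Linear position of the last element under the old layout map.
  %lastlin = affine.apply affine_map<(d0, d1) -> (d0 * 7 + d1)>(%last0, %c6)
  // Size of the normalized rank-1 memref = last linear position + 1.
  %size = arith.addi %lastlin, %c1 : index
  return %size : index
}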