Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[mlir] Add memref normalization support for reinterpret_cast op #133417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 30, 2025

Conversation

arnab-polymage
Copy link
Contributor

@arnab-polymage arnab-polymage commented Mar 28, 2025

Rewrites the memref defined by reinterpet_cast op to have an identity layout map
and updates all its indexing uses. Also extend replaceAllMemRefUsesWith utility
to work when there are multiple occurences of oldMemRef in op's operand list
when op is non-dereferencing.

Fixes #122090
Fixes #121091

@llvmbot
Copy link
Member

llvmbot commented Mar 28, 2025

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-affine

@llvm/pr-subscribers-mlir

Author: Arnab Dutta (arnab-polymage)

Changes

Normalize reinterpret_cast op for statically shaped input and output memrefs.


Patch is 20.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133417.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/Utils.h (+6-3)
  • (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+139-28)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp (+9)
  • (modified) mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir (-140)
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index ff1900bc8f2eb..1032d4d92b589 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -32,6 +32,7 @@ class FuncOp;
 namespace memref {
 class AllocOp;
 class AllocaOp;
+class ReinterpretCastOp;
 } // namespace memref
 
 namespace affine {
@@ -243,9 +244,9 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                        ArrayRef<Value> symbolOperands = {},
                                        bool allowNonDereferencingOps = false);
 
-/// Rewrites the memref defined by this alloc op to have an identity layout map
-/// and updates all its indexing uses. Returns failure if any of its uses
-/// escape (while leaving the IR in a valid state).
+/// Rewrites the memref defined by alloc or reinterpret_cast op to have an
+/// identity layout map and updates all its indexing uses. Returns failure if
+/// any of its uses escape (while leaving the IR in a valid state).
 template <typename AllocLikeOp>
 LogicalResult normalizeMemRef(AllocLikeOp *op);
 extern template LogicalResult
@@ -253,6 +254,8 @@ normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
 extern template LogicalResult
 normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
 
+LogicalResult normalizeMemRef(memref::ReinterpretCastOp *op);
+
 /// Normalizes `memrefType` so that the affine layout map of the memref is
 /// transformed to an identity map with a new shape being computed for the
 /// normalized memref type and returns it. The old memref type is simplify
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 2723cff6900d0..f7b86cfbbd76c 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1101,7 +1101,7 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
 // Private helper function to transform memref.load with reduced rank.
 // This function will modify the indices of the memref.load to match the
 // newMemRef.
-LogicalResult transformMemRefLoadWithReducedRank(
+LogicalResult transformMemRefLoadOrStoreWithReducedRank(
     Operation *op, Value oldMemRef, Value newMemRef, unsigned memRefOperandPos,
     ArrayRef<Value> extraIndices, ArrayRef<Value> extraOperands,
     ArrayRef<Value> symbolOperands, AffineMap indexRemap) {
@@ -1182,6 +1182,14 @@ LogicalResult transformMemRefLoadWithReducedRank(
 
   return success();
 }
+
+// Checks if `op` is non dereferencing.
+// TODO: This hardcoded check will be removed once the right interface is added.
+static bool isDereferencingOp(Operation *op) {
+  return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(
+      op);
+}
+
 // Perform the replacement in `op`.
 LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     Value oldMemRef, Value newMemRef, Operation *op,
@@ -1228,41 +1236,44 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   // The following checks if op is dereferencing memref and performs the access
   // index rewrites.
   auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
-  if (!affMapAccInterface) {
+  if (!isDereferencingOp(op)) {
     if (!allowNonDereferencingOps) {
       // Failure: memref used in a non-dereferencing context (potentially
       // escapes); no replacement in these cases unless allowNonDereferencingOps
       // is set.
       return failure();
     }
+    op->setOperand(memRefOperandPos, newMemRef);
+    return success();
+  }
 
-    // Check if it is a memref.load
-    auto memrefLoad = dyn_cast<memref::LoadOp>(op);
-    bool isReductionLike =
-        indexRemap.getNumResults() < indexRemap.getNumInputs();
-    if (!memrefLoad || !isReductionLike) {
-      op->setOperand(memRefOperandPos, newMemRef);
-      return success();
-    }
-
-    return transformMemRefLoadWithReducedRank(
-        op, oldMemRef, newMemRef, memRefOperandPos, extraIndices, extraOperands,
-        symbolOperands, indexRemap);
+  // Perform index rewrites for the dereferencing op and then replace the op.
+  SmallVector<Value, 4> oldMapOperands;
+  AffineMap oldMap;
+  unsigned oldMemRefNumIndices = oldMemRefRank;
+  if (affMapAccInterface) {
+    // If `op` implements AffineMapAccessInterface, we can get the indices by
+    // quering the number of map operands from the operand list from a certain
+    // offset (`memRefOperandPos` in this case).
+    NamedAttribute oldMapAttrPair =
+        affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
+    oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
+    oldMemRefNumIndices = oldMap.getNumInputs();
+    oldMapOperands.assign(op->operand_begin() + memRefOperandPos + 1,
+                          op->operand_begin() + memRefOperandPos + 1 +
+                              oldMemRefNumIndices);
+  } else {
+    oldMapOperands.assign(op->operand_begin() + memRefOperandPos + 1,
+                          op->operand_begin() + memRefOperandPos + 1 +
+                              oldMemRefRank);
   }
-  // Perform index rewrites for the dereferencing op and then replace the op
-  NamedAttribute oldMapAttrPair =
-      affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
-  AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
-  unsigned oldMapNumInputs = oldMap.getNumInputs();
-  SmallVector<Value, 4> oldMapOperands(
-      op->operand_begin() + memRefOperandPos + 1,
-      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
 
   // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
   SmallVector<Value, 4> oldMemRefOperands;
   SmallVector<Value, 4> affineApplyOps;
   oldMemRefOperands.reserve(oldMemRefRank);
-  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
+  if (affMapAccInterface &&
+      oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
     for (auto resultExpr : oldMap.getResults()) {
       auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                          oldMap.getNumSymbols(), resultExpr);
@@ -1287,7 +1298,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
 
   SmallVector<Value, 4> remapOutputs;
   remapOutputs.reserve(oldMemRefRank);
-
   if (indexRemap &&
       indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
     // Remapped indices.
@@ -1303,7 +1313,6 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
     // No remapping specified.
     remapOutputs.assign(remapOperands.begin(), remapOperands.end());
   }
-
   SmallVector<Value, 4> newMapOperands;
   newMapOperands.reserve(newMemRefRank);
 
@@ -1338,13 +1347,26 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   state.operands.push_back(newMemRef);
 
   // Insert the new memref map operands.
-  state.operands.append(newMapOperands.begin(), newMapOperands.end());
+  if (affMapAccInterface) {
+    state.operands.append(newMapOperands.begin(), newMapOperands.end());
+  } else {
+    // In the case of dereferencing ops not implementing
+    // AffineMapAccessInterface, we need to apply the values of `newMapOperands`
+    // to the `newMap` to get the correct indices.
+    for (unsigned i = 0; i < newMemRefRank; i++)
+      state.operands.push_back(builder.create<AffineApplyOp>(
+          op->getLoc(),
+          AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(),
+                         newMap.getResult(i)),
+          newMapOperands));
+  }
 
   // Insert the remaining operands unmodified.
+  unsigned oldMapNumInputs = oldMapOperands.size();
+
   state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                             oldMapNumInputs,
                         op->operand_end());
-
   // Result types don't change. Both memref's are of the same elemental type.
   state.types.reserve(op->getNumResults());
   for (auto result : op->getResults())
@@ -1353,7 +1375,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
   // Add attribute for 'newMap', other Attributes do not change.
   auto newMapAttr = AffineMapAttr::get(newMap);
   for (auto namedAttr : op->getAttrs()) {
-    if (namedAttr.getName() == oldMapAttrPair.getName())
+    if (affMapAccInterface &&
+        namedAttr.getName() ==
+            affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef).getName())
       state.attributes.push_back({namedAttr.getName(), newMapAttr});
     else
       state.attributes.push_back(namedAttr);
@@ -1846,6 +1870,93 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
   return success();
 }
 
+LogicalResult
+mlir::affine::normalizeMemRef(memref::ReinterpretCastOp *reinterpretCastOp) {
+  MemRefType memrefType = reinterpretCastOp->getType();
+  AffineMap oldLayoutMap = memrefType.getLayout().getAffineMap();
+  Value oldMemRef = reinterpretCastOp->getResult();
+
+  // If `oldLayoutMap` is identity, `memrefType` is already normalized.
+  if (oldLayoutMap.isIdentity())
+    return success();
+
+  // Fetch a new memref type after normalizing the old memref to have an
+  // identity map layout.
+  MemRefType newMemRefType =
+      normalizeMemRefType(memrefType);
+  if (newMemRefType == memrefType)
+    // `oldLayoutMap` couldn't be transformed to an identity map.
+    return failure();
+
+  uint64_t newRank = newMemRefType.getRank();
+  SmallVector<Value> mapOperands(oldLayoutMap.getNumDims() +
+                                 oldLayoutMap.getNumSymbols());
+  SmallVector<Value> oldStrides = reinterpretCastOp->getStrides();
+  Location loc = reinterpretCastOp->getLoc();
+  // As `newMemRefType` is normalized, it is unit strided.
+  SmallVector<int64_t> newStaticStrides(newRank, 1);
+  ArrayRef<int64_t> oldShape = memrefType.getShape();
+  mlir::ValueRange oldSizes = reinterpretCastOp->getSizes();
+  unsigned idx = 0;
+  SmallVector<int64_t> newStaticSizes;
+  OpBuilder b(*reinterpretCastOp);
+  // Collectthe map operands which will be used to compute the new normalized
+  // memref shape.
+  for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
+    if (oldShape[i] == ShapedType::kDynamic)
+      mapOperands[i] =
+          b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
+                                  b.create<arith::ConstantIndexOp>(loc, 1));
+    else
+      mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1);
+  }
+  for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
+    mapOperands[memrefType.getRank() + i] = oldStrides[i];
+  SmallVector<Value> newSizes;
+  ArrayRef<int64_t> newShape = newMemRefType.getShape();
+  // Compute size along all the dimensions of the new normalized memref.
+  for (unsigned i = 0; i < newRank; i++) {
+    if (newShape[i] != ShapedType::kDynamic)
+      continue;
+    newSizes.push_back(b.create<AffineApplyOp>(
+        loc,
+        AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(),
+                       oldLayoutMap.getResult(i)),
+        mapOperands));
+  }
+  for (unsigned i = 0, e = newSizes.size(); i < e; i++)
+    newSizes[i] =
+        b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i],
+                                b.create<arith::ConstantIndexOp>(loc, 1));
+  // Create the new reinterpret_cast op.
+  memref::ReinterpretCastOp newReinterpretCast =
+      b.create<memref::ReinterpretCastOp>(
+          loc, newMemRefType, reinterpretCastOp->getSource(),
+          reinterpretCastOp->getOffsets(), newSizes, mlir::ValueRange(),
+          /*static_offsets=*/reinterpretCastOp->getStaticOffsets(),
+          /*static_sizes=*/newShape,
+          /*static_strides=*/newStaticStrides);
+
+  // Replace all uses of the old memref.
+  if (failed(replaceAllMemRefUsesWith(oldMemRef,
+                                      /*newMemRef=*/newReinterpretCast,
+                                      /*extraIndices=*/{},
+                                      /*indexRemap=*/oldLayoutMap,
+                                      /*extraOperands=*/{},
+                                      /*symbolOperands=*/oldStrides,
+                                      /*domOpFilter=*/nullptr,
+                                      /*postDomOpFilter=*/nullptr,
+                                      /*allowNonDereferencingOps=*/true))) {
+    // If it failed (due to escapes for example), bail out.
+    newReinterpretCast->erase();
+    return failure();
+  }
+
+  oldMemRef.replaceAllUsesWith(newReinterpretCast);
+  reinterpretCastOp->erase();
+  return success();
+}
+
 template LogicalResult
 mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
 template LogicalResult
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index 08b853fe65b85..b8d8a99c33084 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -363,6 +363,15 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
   for (memref::AllocaOp allocaOp : allocaOps)
     (void)normalizeMemRef(&allocaOp);
 
+  // Turn memrefs' non-identity layouts maps into ones with identity. Collect
+  // reinterpret_cast ops first and then process since normalizeMemRef
+  // replaces/erases ops during memref rewriting.
+  SmallVector<memref::ReinterpretCastOp> reinterpretCastOps;
+  funcOp.walk(
+      [&](memref::ReinterpretCastOp op) { reinterpretCastOps.push_back(op); });
+  for (memref::ReinterpretCastOp reinterpretCastOp : reinterpretCastOps)
+    (void)normalizeMemRef(&reinterpretCastOp);
+
   // We use this OpBuilder to create new memref layout later.
   OpBuilder b(funcOp);
 
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
index 3bede131325a7..f02bbde1acd9e 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir
@@ -25,143 +25,3 @@ func.func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
     // CHECK: memref.dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
     return
 }
-
-// Same test with op_nonnorm, with maps in the arguments and the operations in the function.
-
-// CHECK-LABEL: test_nonnorm
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32, #[[MAP:.*]]>)
-func.func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
-    %0 = memref.alloc() : memref<1x16x14x14xf32, #map0>
-    "test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
-    memref.dealloc %0 :  memref<1x16x14x14xf32, #map0>
-
-    // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x14x14xf32, #[[MAP]]>
-    // CHECK: "test.op_nonnorm"(%[[ARG0]], %[[v0]]) : (memref<1x16x14x14xf32, #[[MAP]]>, memref<1x16x14x14xf32, #[[MAP]]>) -> ()
-    // CHECK: memref.dealloc %[[v0]] : memref<1x16x14x14xf32, #[[MAP]]>
-    return
-}
-
-// Test with op_nonnorm whose memref map layouts are identity. This op_nonnorm
-// does not block the normalization of other operations.
-
-// CHECK-LABEL: test_nonnorm_identity_layout
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>)
-func.func @test_nonnorm_identity_layout(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
-    %0 = memref.alloc() : memref<1x16x14x14xf32>
-    "test.op_nonnorm"(%0, %0) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
-    "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32>) -> ()
-    memref.dealloc %0 :  memref<1x16x14x14xf32>
-
-    // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x14x14xf32>
-    // CHECK: "test.op_nonnorm"(%[[v0]], %[[v0]]) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
-    // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32>) -> ()
-    // CHECK: memref.dealloc %[[v0]] : memref<1x16x14x14xf32>
-    return
-}
-
-// Test with op_norm, with maps in the operations in the function.
-
-// CHECK-LABEL: test_norm_mix
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>
-func.func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () {
-    %0 = memref.alloc() : memref<1x16x14x14xf32, #map0>
-    "test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32, #map0>) -> ()
-    memref.dealloc %0 :  memref<1x16x14x14xf32, #map0>
-
-    // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x1x1x32x64xf32>
-    // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
-    // CHECK: memref.dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
-    return
-}
-
-// Test with maps in load and store ops.
-
-#map_tile = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 32, d2 mod 32, d3 mod 32)>
-
-// CHECK-LABEL: test_load_store
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32>
-func.func @test_load_store(%arg0 : memref<1x16x14x14xf32>) -> () {
-    %0 = memref.alloc() : memref<1x16x14x14xf32, #map_tile>
-    // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x1x1x32x32xf32>
-    %1 = memref.alloc() : memref<1x16x14x14xf32>
-    // CHECK: %[[v1:.*]] = memref.alloc() : memref<1x16x14x14xf32>
-    "test.op_norm"(%0, %1) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) -> ()
-    // CHECK: "test.op_norm"(%[[v0]], %[[v1]]) : (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) -> ()
-    %cst = arith.constant 3.0 : f32
-    affine.for %i = 0 to 1 {
-      affine.for %j = 0 to 16 {
-        affine.for %k = 0 to 14 {
-          affine.for %l = 0 to 14 {
-            %2 = memref.load %1[%i, %j, %k, %l] : memref<1x16x14x14xf32>
-            // CHECK: memref<1x16x14x14xf32>
-            %3 = arith.addf %2, %cst : f32
-            memref.store %3, %arg0[%i, %j, %k, %l] : memref<1x16x14x14xf32>
-            // CHECK: memref<1x16x14x14xf32>
-          }
-        }
-      }
-    }
-    memref.dealloc %0 :  memref<1x16x14x14xf32, #map_tile>
-    // CHECK: memref.dealloc %[[v0]] : memref<1x16x1x1x32x32xf32>
-    memref.dealloc %1 :  memref<1x16x14x14xf32>
-    // CHECK: memref.dealloc %[[v1]] : memref<1x16x14x14xf32>
-    return
-}
-
-// Test with op_norm_ret, with maps in the results of normalizable operation.
-
-// CHECK-LABEL: test_norm_ret
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) {
-func.func @test_norm_ret(%arg0: memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) {
-    %0 = memref.alloc() : memref<1x16x14x14xf32, #map_tile>
-    // CHECK-NEXT: %[[v0:.*]] = memref.alloc() : memref<1x16x1x1x32x32xf32>
-    %1, %2 = "test.op_norm_ret"(%arg0) : (memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>)
-    // CHECK-NEXT: %[[v1:.*]], %[[v2:.*]] = "test.op_norm_ret"
-    // CHECK-SAME: (memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>)
-    "test.op_norm"(%1, %0) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32, #map_tile>) -> ()
-    // CHECK-NEXT: "test.op_norm"
-    // CHECK-SAME: : (memref<1x16x1x1x32x32xf32>, memref<1x16x1x1x32x32xf32>) -> ()
-    memref.dealloc %0 : memref<1x16x14x14xf32, #map_tile>
-    // CHECK-NEXT: memref.dealloc %[[v0]] : memref<1x16x1x1x32x32xf32>
-    return %1, %2 : memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>
-    // CHECK-NEXT: return %[[v1]], %[[v2]] : memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>
-}
-
-// Test with an arbitrary op that references the function symbol.
-
-"test.op_funcref"() {func = @test_norm_mix} : () -> ()
-
-
-// -----
-
-#map_1d_tile = affine_map<(d0) -> (d0 floordiv 32, d0 mod 32)>
-
-// Test with memref.reinterpret_cast
-
-// CHECK-LABEL: test_norm_reinterpret_cast
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1x32xf32>) -> memref<3x1x1xf32> {
-func.func @test_norm_reinterpret_cast(%arg0 : memref<3xf32, #map_1d_tile>) -> (memref<3x1x1xf32>) {
-    %0 = memref.alloc() : memref<3xf32>
-    "test.op_norm"(%arg0, %0) : (memref<3xf32, #map_1d_tile>, memref<3xf32>) -> ()
-    %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [3, 1, 1], strides: [1, 1, 1] : memref<3xf32> to memref<3x1x1xf32>
-    // CHECK: %[[v0:.*]] = memref.alloc() : memref<3xf32>
-    // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x32xf32>, memref<3xf32>) -> ()
-    // CHECK: memref.reinterpret_ca...
[truncated]

@arnab-polymage arnab-polymage force-pushed the ornib/normalize_memref_bug branch from e36680a to e5a8ddf Compare March 28, 2025 10:38
Copy link

github-actions bot commented Mar 28, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@arnab-polymage arnab-polymage force-pushed the ornib/normalize_memref_bug branch from b2de115 to 821202f Compare March 31, 2025 04:21
@arnab-polymage arnab-polymage force-pushed the ornib/normalize_memref_bug branch 2 times, most recently from e05179d to 6b8a666 Compare March 31, 2025 06:07
@joker-eph joker-eph changed the title Normalize reinterpret_cast op [mlir] Fix Memref Normalization handling of reinterpret_cast op Apr 3, 2025
@arnab-polymage arnab-polymage changed the title [mlir] Fix Memref Normalization handling of reinterpret_cast op [mlir] Add Memref Normalization support for reinterpret_cast op Apr 3, 2025
@arnab-polymage arnab-polymage force-pushed the ornib/normalize_memref_bug branch 2 times, most recently from 1eda5c1 to 4376b61 Compare April 7, 2025 13:14
@bondhugula
Copy link
Contributor

Fix format as well.

@arnab-polymage arnab-polymage force-pushed the ornib/normalize_memref_bug branch from 4376b61 to 24e4b70 Compare April 10, 2025 04:09
Copy link
Contributor

@bondhugula bondhugula left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking mostly good. Some comments.

@bondhugula bondhugula changed the title [mlir] Add Memref Normalization support for reinterpret_cast op [mlir] Add memref normalization support for reinterpret_cast op Apr 25, 2025
@arnab-polymage arnab-polymage force-pushed the ornib/normalize_memref_bug branch from 24e4b70 to 54520e5 Compare April 28, 2025 08:50
@arnab-polymage arnab-polymage force-pushed the ornib/normalize_memref_bug branch from 54520e5 to 58e5a69 Compare April 28, 2025 10:06
Copy link
Contributor

@bondhugula bondhugula left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM - thanks.

@arnab-polymage arnab-polymage force-pushed the ornib/normalize_memref_bug branch 2 times, most recently from 5af4e71 to 3ca4a47 Compare April 30, 2025 06:07
Rewrites the memref defined by reinterpet_cast op to have an identity layout map
and updates all its indexing uses. Also extend `replaceAllMemRefUsesWith` utility
to work when there are multiple occurences of `oldMemRef` in `op`'s operand list
when op is non-dereferencing.
@arnab-polymage arnab-polymage force-pushed the ornib/normalize_memref_bug branch from 3ca4a47 to 0b5f542 Compare April 30, 2025 06:07
@bondhugula bondhugula merged commit 99cb3f7 into llvm:main Apr 30, 2025
11 checks passed
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…#133417)

Rewrites memrefs defined by reinterpet_cast ops to have an identity
layout map and updates all their indexing uses. Also, extend
`replaceAllMemRefUsesWith` utility to work when there are multiple 
occurrences of `oldMemRef` in `op`'s operand list when op is 
non-dereferencing.

Fixes llvm#122090 
Fixes llvm#121091
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…#133417)

Rewrites memrefs defined by reinterpet_cast ops to have an identity
layout map and updates all their indexing uses. Also, extend
`replaceAllMemRefUsesWith` utility to work when there are multiple 
occurrences of `oldMemRef` in `op`'s operand list when op is 
non-dereferencing.

Fixes llvm#122090 
Fixes llvm#121091
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…#133417)

Rewrites memrefs defined by reinterpet_cast ops to have an identity
layout map and updates all their indexing uses. Also, extend
`replaceAllMemRefUsesWith` utility to work when there are multiple 
occurrences of `oldMemRef` in `op`'s operand list when op is 
non-dereferencing.

Fixes llvm#122090 
Fixes llvm#121091
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
…#133417)

Rewrites memrefs defined by reinterpet_cast ops to have an identity
layout map and updates all their indexing uses. Also, extend
`replaceAllMemRefUsesWith` utility to work when there are multiple 
occurrences of `oldMemRef` in `op`'s operand list when op is 
non-dereferencing.

Fixes llvm#122090 
Fixes llvm#121091
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
4 participants