Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[mlir][memref] Add runtime verification for memref.atomic_rmw #130414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Mar 8, 2025

Implement runtime verification for memref.atomic_rmw and memref.generic_atomic_rmw. Also add a missing test for memref.store.

@llvmbot
Copy link
Member

llvmbot commented Mar 8, 2025

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Implement runtime verification for memref.atomic_rmw and memref.generic_atomic_rmw.


Full diff: https://github.com/llvm/llvm-project/pull/130414.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+26-19)
  • (added) mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir (+45)
  • (added) mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir (+45)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index c8e7325d7ac89..ceea27a35a225 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -35,6 +35,26 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
   return inBounds;
 }
 
+/// Generate a runtime check to see if the given indices are in-bounds with
+/// respect to the given ranked memref.
+Value generateIndicesInBoundsCheck(OpBuilder &builder, Location loc,
+                                   Value memref, ValueRange indices) {
+  auto memrefType = cast<MemRefType>(memref.getType());
+  assert(memrefType.getRank() == static_cast<int64_t>(indices.size()) &&
+         "rank mismatch");
+  Value cond = builder.create<arith::ConstantOp>(
+      loc, builder.getIntegerAttr(builder.getI1Type(), 1));
+
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  for (auto [dim, idx] : llvm::enumerate(indices)) {
+    Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, dim);
+    Value inBounds = generateInBoundsCheck(builder, loc, idx, zero, dimOp);
+    cond = builder.createOrFold<arith::AndIOp>(loc, cond, inBounds);
+  }
+
+  return cond;
+}
+
 struct AssumeAlignmentOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<
           AssumeAlignmentOpInterface, AssumeAlignmentOp> {
@@ -186,26 +206,10 @@ struct LoadStoreOpInterface
   void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                    Location loc) const {
     auto loadStoreOp = cast<LoadStoreOp>(op);
-
-    auto memref = loadStoreOp.getMemref();
-    auto rank = memref.getType().getRank();
-    if (rank == 0) {
-      return;
-    }
-    auto indices = loadStoreOp.getIndices();
-
-    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-    Value assertCond;
-    for (auto i : llvm::seq<int64_t>(0, rank)) {
-      Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
-      Value inBounds =
-          generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
-      assertCond =
-          i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
-                : inBounds;
-    }
+    Value cond = generateIndicesInBoundsCheck(
+        builder, loc, loadStoreOp.getMemref(), loadStoreOp.getIndices());
     builder.create<cf::AssertOp>(
-        loc, assertCond,
+        loc, cond,
         RuntimeVerifiableOpInterface::generateErrorMessage(
             op, "out-of-bounds access"));
   }
@@ -377,9 +381,12 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
     AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
+    AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
     CastOp::attachInterface<CastOpInterface>(*ctx);
     DimOp::attachInterface<DimOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
+    GenericAtomicRMWOp::attachInterface<
+        LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
     LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
     ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
     StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
new file mode 100644
index 0000000000000..9f70c5ca66f65
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -test-cf-assert \
+// RUN:     -expand-strided-metadata \
+// RUN:     -lower-affine \
+// RUN:     -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @store_dynamic(%memref: memref<?xf32>, %index: index) {
+  %cst = arith.constant 1.0 : f32
+  memref.atomic_rmw addf %cst, %memref[%index] : (f32, memref<?xf32>) -> f32
+  return
+}
+
+func.func @main() {
+  // Allocate a memref<10xf32>, but disguise it as a memref<5xf32>. This is
+  // necessary because "-test-cf-assert" does not abort the program and we do
+  // not want to segfault when running the test case.
+  %alloc = memref.alloca() : memref<10xf32>
+  %ptr = memref.extract_aligned_pointer_as_index %alloc : memref<10xf32> -> index
+  %ptr_i64 = arith.index_cast %ptr : index to i64
+  %ptr_llvm = llvm.inttoptr %ptr_i64 : i64 to !llvm.ptr
+  %c0 = llvm.mlir.constant(0 : index) : i64
+  %c1 = llvm.mlir.constant(1 : index) : i64
+  %c5 = llvm.mlir.constant(5 : index) : i64
+  %4 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %5 = llvm.insertvalue %ptr_llvm, %4[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %6 = llvm.insertvalue %ptr_llvm, %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %8 = llvm.insertvalue %c0, %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %9 = llvm.insertvalue %c5, %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %10 = llvm.insertvalue %c1, %9[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %buffer = builtin.unrealized_conversion_cast %10 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<5xf32>
+  %cast = memref.cast %buffer : memref<5xf32> to memref<?xf32>
+
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.atomic_rmw"(%{{.*}}, %{{.*}}, %{{.*}}) <{kind = 0 : i64}> : (f32, memref<?xf32>, index) -> f32
+  // CHECK-NEXT: ^ out-of-bounds access
+  // CHECK-NEXT: Location: loc({{.*}})
+  %c9 = arith.constant 9 : index
+  func.call @store_dynamic(%cast, %c9) : (memref<?xf32>, index) -> ()
+
+  return
+}
+
diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
new file mode 100644
index 0000000000000..58961ba31d93a
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -test-cf-assert \
+// RUN:     -expand-strided-metadata \
+// RUN:     -lower-affine \
+// RUN:     -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @store_dynamic(%memref: memref<?xf32>, %index: index) {
+  %cst = arith.constant 1.0 : f32
+  memref.store %cst, %memref[%index] :  memref<?xf32>
+  return
+}
+
+func.func @main() {
+  // Allocate a memref<10xf32>, but disguise it as a memref<5xf32>. This is
+  // necessary because "-test-cf-assert" does not abort the program and we do
+  // not want to segfault when running the test case.
+  %alloc = memref.alloca() : memref<10xf32>
+  %ptr = memref.extract_aligned_pointer_as_index %alloc : memref<10xf32> -> index
+  %ptr_i64 = arith.index_cast %ptr : index to i64
+  %ptr_llvm = llvm.inttoptr %ptr_i64 : i64 to !llvm.ptr
+  %c0 = llvm.mlir.constant(0 : index) : i64
+  %c1 = llvm.mlir.constant(1 : index) : i64
+  %c5 = llvm.mlir.constant(5 : index) : i64
+  %4 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %5 = llvm.insertvalue %ptr_llvm, %4[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %6 = llvm.insertvalue %ptr_llvm, %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %8 = llvm.insertvalue %c0, %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %9 = llvm.insertvalue %c5, %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %10 = llvm.insertvalue %c1, %9[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  %buffer = builtin.unrealized_conversion_cast %10 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<5xf32>
+  %cast = memref.cast %buffer : memref<5xf32> to memref<?xf32>
+
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: "memref.store"(%{{.*}}, %{{.*}}, %{{.*}}) : (f32, memref<?xf32>, index) -> ()
+  // CHECK-NEXT: ^ out-of-bounds access
+  // CHECK-NEXT: Location: loc({{.*}})
+  %c9 = arith.constant 9 : index
+  func.call @store_dynamic(%cast, %c9) : (memref<?xf32>, index) -> ()
+
+  return
+}
+

@matthias-springer matthias-springer force-pushed the users/matthias-springer/runtime_verify_memref_assume branch from 7cf452f to d6b573e Compare March 18, 2025 09:19
Base automatically changed from users/matthias-springer/runtime_verify_memref_assume to main March 19, 2025 20:23
@matthias-springer matthias-springer force-pushed the users/matthias-springer/atomic_rmw_verification branch from c37848a to 18f917b Compare March 20, 2025 07:31
@matthias-springer matthias-springer force-pushed the users/matthias-springer/atomic_rmw_verification branch 2 times, most recently from d481688 to 7048efa Compare April 28, 2025 13:17
@matthias-springer matthias-springer force-pushed the users/matthias-springer/atomic_rmw_verification branch from 7048efa to fc034e0 Compare April 28, 2025 13:18
@matthias-springer matthias-springer merged commit 120e940 into main Apr 30, 2025
6 of 11 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/atomic_rmw_verification branch April 30, 2025 11:45
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…#130414)

Implement runtime verification for `memref.atomic_rmw` and
`memref.generic_atomic_rmw`. Also add a missing test for `memref.store`.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…#130414)

Implement runtime verification for `memref.atomic_rmw` and
`memref.generic_atomic_rmw`. Also add a missing test for `memref.store`.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…#130414)

Implement runtime verification for `memref.atomic_rmw` and
`memref.generic_atomic_rmw`. Also add a missing test for `memref.store`.
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
…#130414)

Implement runtime verification for `memref.atomic_rmw` and
`memref.generic_atomic_rmw`. Also add a missing test for `memref.store`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants