-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][affine] Fix a crash when cast incompatible type #145162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This PR fixes a crash in `getSemiAffineExprFromFlatForm` when localExpr is not `AffineBinaryOpExpr`.
@llvm/pr-subscribers-mlir Author: Longsheng Mou (CoTinker) ChangesThis PR fixes a crash in Full diff: https://github.com/llvm/llvm-project/pull/145162.diff 2 Files Affected:
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index c8d9761511bec..cc81f9d19aca7 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1174,11 +1174,15 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
// the indices in `coefficients` map, and affine expression corresponding to
// in indices in `indexToExprMap` map.
for (const auto &it : llvm::enumerate(localExprs)) {
- AffineExpr expr = it.value();
if (flatExprs[numDims + numSymbols + it.index()] == 0)
continue;
- AffineExpr lhs = cast<AffineBinaryOpExpr>(expr).getLHS();
- AffineExpr rhs = cast<AffineBinaryOpExpr>(expr).getRHS();
+ AffineExpr expr = it.value();
+ auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+ if (!binaryExpr)
+ continue;
+
+ AffineExpr lhs = binaryExpr.getLHS();
+ AffineExpr rhs = binaryExpr.getRHS();
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
(isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) ||
isa<AffineConstantExpr>(rhs)))) {
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index e4a8512b002ee..6f2737a982752 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -592,3 +592,19 @@ func.func @semiaffine_modulo_dim(%arg0: index, %arg1: index, %arg2: index) -> in
//CHECK: affine.apply #[[$MAP]]()[%{{.*}}, %{{.*}}, %{{.*}}]
return %a : index
}
+
+// -----
+
+// CHECK-LABEL: func @semiaffine_simplification_floordiv_and_ceildiv_const
+func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor<?xf32>) -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c13 = arith.constant 13 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %a = affine.apply affine_map<()[s0, s1, s2] -> (s0 floordiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
+ %b = affine.apply affine_map<()[s0, s1, s2] -> (s0 ceildiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
+ // CHECK: %[[C6:.*]] = arith.constant 6 : index
+ // CHECK-NEXT: %[[C7:.*]] = arith.constant 7 : index
+ // CHECK-NEXT: return %[[C6]], %[[C7]]
+ return %a, %b : index, index
+}
|
@llvm/pr-subscribers-mlir-core Author: Longsheng Mou (CoTinker) ChangesThis PR fixes a crash in Full diff: https://github.com/llvm/llvm-project/pull/145162.diff 2 Files Affected:
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index c8d9761511bec..cc81f9d19aca7 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1174,11 +1174,15 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
// the indices in `coefficients` map, and affine expression corresponding to
// in indices in `indexToExprMap` map.
for (const auto &it : llvm::enumerate(localExprs)) {
- AffineExpr expr = it.value();
if (flatExprs[numDims + numSymbols + it.index()] == 0)
continue;
- AffineExpr lhs = cast<AffineBinaryOpExpr>(expr).getLHS();
- AffineExpr rhs = cast<AffineBinaryOpExpr>(expr).getRHS();
+ AffineExpr expr = it.value();
+ auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+ if (!binaryExpr)
+ continue;
+
+ AffineExpr lhs = binaryExpr.getLHS();
+ AffineExpr rhs = binaryExpr.getRHS();
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
(isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) ||
isa<AffineConstantExpr>(rhs)))) {
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index e4a8512b002ee..6f2737a982752 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -592,3 +592,19 @@ func.func @semiaffine_modulo_dim(%arg0: index, %arg1: index, %arg2: index) -> in
//CHECK: affine.apply #[[$MAP]]()[%{{.*}}, %{{.*}}, %{{.*}}]
return %a : index
}
+
+// -----
+
+// CHECK-LABEL: func @semiaffine_simplification_floordiv_and_ceildiv_const
+func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor<?xf32>) -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c13 = arith.constant 13 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %a = affine.apply affine_map<()[s0, s1, s2] -> (s0 floordiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
+ %b = affine.apply affine_map<()[s0, s1, s2] -> (s0 ceildiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
+ // CHECK: %[[C6:.*]] = arith.constant 6 : index
+ // CHECK-NEXT: %[[C7:.*]] = arith.constant 7 : index
+ // CHECK-NEXT: return %[[C6]], %[[C7]]
+ return %a, %b : index, index
+}
|
@llvm/pr-subscribers-mlir-affine Author: Longsheng Mou (CoTinker) ChangesThis PR fixes a crash in Full diff: https://github.com/llvm/llvm-project/pull/145162.diff 2 Files Affected:
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index c8d9761511bec..cc81f9d19aca7 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1174,11 +1174,15 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
// the indices in `coefficients` map, and affine expression corresponding to
// in indices in `indexToExprMap` map.
for (const auto &it : llvm::enumerate(localExprs)) {
- AffineExpr expr = it.value();
if (flatExprs[numDims + numSymbols + it.index()] == 0)
continue;
- AffineExpr lhs = cast<AffineBinaryOpExpr>(expr).getLHS();
- AffineExpr rhs = cast<AffineBinaryOpExpr>(expr).getRHS();
+ AffineExpr expr = it.value();
+ auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+ if (!binaryExpr)
+ continue;
+
+ AffineExpr lhs = binaryExpr.getLHS();
+ AffineExpr rhs = binaryExpr.getRHS();
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
(isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) ||
isa<AffineConstantExpr>(rhs)))) {
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index e4a8512b002ee..6f2737a982752 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -592,3 +592,19 @@ func.func @semiaffine_modulo_dim(%arg0: index, %arg1: index, %arg2: index) -> in
//CHECK: affine.apply #[[$MAP]]()[%{{.*}}, %{{.*}}, %{{.*}}]
return %a : index
}
+
+// -----
+
+// CHECK-LABEL: func @semiaffine_simplification_floordiv_and_ceildiv_const
+func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor<?xf32>) -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c13 = arith.constant 13 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %a = affine.apply affine_map<()[s0, s1, s2] -> (s0 floordiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
+ %b = affine.apply affine_map<()[s0, s1, s2] -> (s0 ceildiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
+ // CHECK: %[[C6:.*]] = arith.constant 6 : index
+ // CHECK-NEXT: %[[C7:.*]] = arith.constant 7 : index
+ // CHECK-NEXT: return %[[C6]], %[[C7]]
+ return %a, %b : index, index
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, LGTM if CI green.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks
This PR fixes a crash in
getSemiAffineExprFromFlatForm
when localExpr is notAffineBinaryOpExpr
. Fixes #144091.