-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][memref] Fix computeCollapsedLayoutMap for contiguous dynamic dim #136485
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir Author: Maya Amrami (amrami) ChangesFull diff: https://github.com/llvm/llvm-project/pull/136485.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 6f10a31c15626..45ac4c9d5117e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -23,6 +23,7 @@
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
+#include <algorithm>
using namespace mlir;
using namespace mlir::memref;
@@ -2401,11 +2402,19 @@ computeCollapsedLayoutMap(MemRefType srcType,
if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
resultStrides.push_back(srcStrides[ref.back()]);
} else {
- // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
- // the corresponding stride may have to be skipped. (See above comment.)
- // Therefore, the result stride cannot be statically determined and must
- // be dynamic.
- resultStrides.push_back(ShapedType::kDynamic);
+ bool contiguousSrcDim = srcStrides[ref.back()] == 1;
+ bool dynamicSizeIsPreserved =
+ std::all_of(ref.begin(), ref.end() - 1,
+ [srcShape](int64_t dim) { return srcShape[dim] == 1; });
+ if (contiguousSrcDim && dynamicSizeIsPreserved)
+ resultStrides.push_back(1);
+ else {
+ // Dynamically-sized dims may turn out to be dims of size 1 at runtime,
+ // so the corresponding stride may have to be skipped. (See above
+ // comment.) Therefore, the result stride cannot be statically
+ // determined and must be dynamic.
+ resultStrides.push_back(ShapedType::kDynamic);
+ }
}
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 34fc4775924e7..a8fcd91fba097 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -502,6 +502,14 @@ func.func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) {
// -----
+func.func @collapse_shape_infer_stride_of_dynamic_dim(%arg0: memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>, %dim : index) -> (memref<?xsi32, strided<[?]>, 1>) {
+ // expected-error @+1 {{expected collapsed type to be 'memref<?xsi32, strided<[1]>, 1>' but found 'memref<?xsi32, strided<[?]>, 1>'}}
+ %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into memref<?xsi32, strided<[?]>, 1>
+ return %collapse_shape : memref<?xsi32, strided<[?]>, 1>
+}
+
+// -----
+
func.func @expand_shape_illegal_static_memref
(%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> {
// expected-error @+1 {{collapsed dim size (20) must equal reassociation group size (40)}}
|
@llvm/pr-subscribers-mlir-memref Author: Maya Amrami (amrami) ChangesFull diff: https://github.com/llvm/llvm-project/pull/136485.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 6f10a31c15626..45ac4c9d5117e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -23,6 +23,7 @@
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
+#include <algorithm>
using namespace mlir;
using namespace mlir::memref;
@@ -2401,11 +2402,19 @@ computeCollapsedLayoutMap(MemRefType srcType,
if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
resultStrides.push_back(srcStrides[ref.back()]);
} else {
- // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
- // the corresponding stride may have to be skipped. (See above comment.)
- // Therefore, the result stride cannot be statically determined and must
- // be dynamic.
- resultStrides.push_back(ShapedType::kDynamic);
+ bool contiguousSrcDim = srcStrides[ref.back()] == 1;
+ bool dynamicSizeIsPreserved =
+ std::all_of(ref.begin(), ref.end() - 1,
+ [srcShape](int64_t dim) { return srcShape[dim] == 1; });
+ if (contiguousSrcDim && dynamicSizeIsPreserved)
+ resultStrides.push_back(1);
+ else {
+ // Dynamically-sized dims may turn out to be dims of size 1 at runtime,
+ // so the corresponding stride may have to be skipped. (See above
+ // comment.) Therefore, the result stride cannot be statically
+ // determined and must be dynamic.
+ resultStrides.push_back(ShapedType::kDynamic);
+ }
}
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 34fc4775924e7..a8fcd91fba097 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -502,6 +502,14 @@ func.func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) {
// -----
+func.func @collapse_shape_infer_stride_of_dynamic_dim(%arg0: memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>, %dim : index) -> (memref<?xsi32, strided<[?]>, 1>) {
+ // expected-error @+1 {{expected collapsed type to be 'memref<?xsi32, strided<[1]>, 1>' but found 'memref<?xsi32, strided<[?]>, 1>'}}
+ %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into memref<?xsi32, strided<[?]>, 1>
+ return %collapse_shape : memref<?xsi32, strided<[?]>, 1>
+}
+
+// -----
+
func.func @expand_shape_illegal_static_memref
(%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> {
// expected-error @+1 {{collapsed dim size (20) must equal reassociation group size (40)}}
|
ping |
ping |
@@ -502,6 +502,14 @@ func.func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) { | |||
|
|||
// ----- | |||
|
|||
func.func @collapse_shape_infer_stride_of_dynamic_dim(%arg0: memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>, %dim : index) -> (memref<?xsi32, strided<[?]>, 1>) { | |||
// expected-error @+1 {{expected collapsed type to be 'memref<?xsi32, strided<[1]>, 1>' but found 'memref<?xsi32, strided<[?]>, 1>'}} | |||
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into memref<?xsi32, strided<[?]>, 1> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add another example where not all of the static source dimensions are 1?
resultStrides.push_back(ShapedType::kDynamic); | ||
bool contiguousSrcDim = srcStrides[ref.back()] == 1; | ||
bool dynamicSizeIsPreserved = | ||
std::all_of(ref.begin(), ref.end() - 1, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is it ref.begin() -> ref.end() - 1
? All dimensions except for one must be 1
, right? In that case, it does not matter where non-unit dimension is?
No description provided.