Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[mlir][vector] linearize vector.insert_strided_slice (flatten to vector.shuffle) #138725

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

newling
Copy link
Contributor

@newling newling commented May 6, 2025

Extends the set of vector operations that we can linearize to include vector.insert_strided_slice. The new pattern reuses the ideas from vector.extract_strided_slice linearization.

@llvmbot
Copy link
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

Extends the set of vector operations that we can linearize to include vector.insert_strided_slice. The new pattern reuses the ideas from vector.extract_strided_slice linearization.


Full diff: https://github.com/llvm/llvm-project/pull/138725.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+196-87)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+47-7)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..8ffb3b0cb2c42 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -109,17 +109,103 @@ struct LinearizeVectorizable final
   }
 };
 
-/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
-/// on a linearized vector.
-/// Following,
+template <typename TOp>
+static bool stridesAllOne(TOp op) {
+  static_assert(
+      std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
+          std::is_same_v<TOp, vector::InsertStridedSliceOp>,
+      "expected vector.extract_strided_slice or vector.insert_strided_slice");
+  ArrayAttr strides = op.getStrides();
+  return llvm::all_of(
+      strides, [](auto stride) { return isConstantIntValue(stride, 1); });
+}
+
+/// Convert an array of attributes into a vector of integers, if possible.
+static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
+  if (!attrs)
+    return failure();
+  SmallVector<int64_t> ints;
+  ints.reserve(attrs.size());
+  for (auto attr : attrs) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+      ints.push_back(intAttr.getInt());
+    } else {
+      return failure();
+    }
+  }
+  return ints;
+}
+
+/// Consider inserting a vector of shape `small` into a vector of shape `large`,
+/// at position `offsets`: this function enumeratates all the indices in `large`
+/// that are written to. The enumeration is with row-major ordering.
+///
+/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
+/// positions written to are (1,3) and (1,4), which have linearized indices 8
+/// and 9. So [8,9] is returned.
+SmallVector<int64_t> static getFlattenedStridedSliceIndices(
+    ArrayRef<int64_t> small, ArrayRef<int64_t> large,
+    ArrayRef<int64_t> offsets) {
+
+  // Example of alignment between, `large`, `small` and `offsets`:
+  //    large  =  4, 5, 6, 7, 8
+  //    small  =     1, 6, 7, 8
+  //  offsets  =  2, 3, 0
+  //
+  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
+  assert(large.size() >= small.size());
+  assert(large.size() >= offsets.size());
+  unsigned delta = large.size() - small.size();
+  unsigned nOffsets = offsets.size();
+  auto getSmall = [&](int64_t i) { return i >= delta ? small[i - delta] : 1; };
+  auto getOffset = [&](int64_t i) { return i < nOffsets ? offsets[i] : 0; };
+
+  // Using 2 vectors of indices, at each iteration populate the updated set of
+  // indices based on the old set of indices, and the size of the small vector
+  // in the current iteration.
+  SmallVector<int64_t> indices{0};
+  SmallVector<int64_t> nextIndices;
+  int64_t stride = 1;
+  for (int i = large.size() - 1; i >= 0; --i) {
+    auto currentSize = indices.size();
+    auto smallSize = getSmall(i);
+    auto nextSize = currentSize * smallSize;
+    nextIndices.resize(nextSize);
+    int64_t *base = nextIndices.begin();
+    int64_t offset = getOffset(i) * stride;
+    for (int j = 0; j < smallSize; ++j) {
+      for (uint64_t k = 0; k < currentSize; ++k) {
+        base[k] = indices[k] + offset;
+      }
+      offset += stride;
+      base += currentSize;
+    }
+    stride *= large[i];
+    std::swap(indices, nextIndices);
+    nextIndices.clear();
+  }
+  return indices;
+}
+
+/// This pattern converts a vector.extract_strided_slice operation into a
+/// vector.shuffle operation that has a rank-1 (linearized) operand and result.
+///
+/// For example, the following:
+///
+/// ```
 ///   vector.extract_strided_slice %source
 ///         { offsets = [..], strides = [..], sizes = [..] }
+/// ```
+///
 /// is converted to :
+/// ```
 ///   %source_1d = vector.shape_cast %source
-///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
-///   %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the offsets and sizes of the
-/// extraction.
+///   %out_1d    = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+///   %out_nd    = vector.shape_cast %out_1d
+/// ```
+///
+/// `shuffle_indices_1d` is computed using the offsets and sizes of the original
+/// vector.extract_strided_slice operation.
 struct LinearizeVectorExtractStridedSlice final
     : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -129,88 +215,110 @@ struct LinearizeVectorExtractStridedSlice final
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
-  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
+  matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp,
+                  OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType dstType =
-        getTypeConverter()->convertType<VectorType>(extractOp.getType());
-    assert(dstType && "vector type destination expected.");
-    if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
-      return rewriter.notifyMatchFailure(extractOp,
-                                         "scalable vectors are not supported.");
 
-    ArrayAttr offsets = extractOp.getOffsets();
-    ArrayAttr sizes = extractOp.getSizes();
-    ArrayAttr strides = extractOp.getStrides();
-    if (!isConstantIntValue(strides[0], 1))
-      return rewriter.notifyMatchFailure(
-          extractOp, "Strided slice with stride != 1 is not supported.");
-    Value srcVector = adaptor.getVector();
-    // If kD offsets are specified for nD source vector (n > k), the granularity
-    // of the extraction is greater than 1. In this case last (n-k) dimensions
-    // form the extraction granularity.
-    // Example :
-    //  vector.extract_strided_slice %src {
-    //      offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
-    //      vector<4x8x8xf32> to vector<2x2x8xf32>
-    // Here, extraction granularity is 8.
-    int64_t extractGranularitySize = 1;
-    int64_t nD = extractOp.getSourceVectorType().getRank();
-    int64_t kD = (int64_t)offsets.size();
-    int64_t k = kD;
-    while (k < nD) {
-      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
-      ++k;
+    VectorType flatOutputType = getTypeConverter()->convertType<VectorType>(
+        extractStridedSliceOp.getType());
+    assert(flatOutputType && "vector type expected");
+
+    if (!stridesAllOne(extractStridedSliceOp)) {
+      return rewriter.notifyMatchFailure(extractStridedSliceOp,
+                                         "strides other than 1 not supported");
     }
-    // Get total number of extracted slices.
-    int64_t nExtractedSlices = 1;
-    for (Attribute size : sizes) {
-      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
+
+    FailureOr<SmallVector<int64_t>> offsets =
+        intsFromArrayAttr(extractStridedSliceOp.getOffsets());
+    if (failed(offsets)) {
+      return rewriter.notifyMatchFailure(extractStridedSliceOp,
+                                         "failed to get integer offsets");
     }
-    // Compute the strides of the source vector considering first k dimensions.
-    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
-    for (int i = kD - 2; i >= 0; --i) {
-      sourceStrides[i] = sourceStrides[i + 1] *
-                         extractOp.getSourceVectorType().getShape()[i + 1];
+
+    ArrayRef<int64_t> inputShape =
+        extractStridedSliceOp.getSourceVectorType().getShape();
+
+    ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
+
+    SmallVector<int64_t> indices = getFlattenedStridedSliceIndices(
+        outputShape, inputShape, offsets.value());
+
+    Value srcVector = adaptor.getVector();
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
+    return success();
+  }
+};
+
+/// This pattern converts a vector.insert_strided_slice operation into a
+/// vector.shuffle operation that has rank-1 (linearized) operands and result.
+///
+/// For example, the following:
+/// ```
+///  %0 = vector.insert_strided_slice %to_store, %into
+///             {offsets = [1, 0, 0, 0], strides = [1, 1]}
+///                  : vector<2x2xi8> into vector<2x1x3x2xi8>
+/// ```
+///
+/// is converted to
+/// ```
+///  %to_store_1d
+///           = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
+///  %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
+///  %out_1d  = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
+///  %out_nd  = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
+/// ```
+///
+/// where shuffle_indices_1d in this case is
+///     [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
+///                        ^^^^^^^^^^^^^^
+///                          to_store_1d
+///
+struct LinearizeVectorInsertStridedSlice final
+    : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
+                                    MLIRContext *context,
+                                    PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!stridesAllOne(insertStridedSliceOp)) {
+      return rewriter.notifyMatchFailure(insertStridedSliceOp,
+                                         "strides other than 1 not supported");
     }
-    // Final shuffle indices has nExtractedSlices * extractGranularitySize
-    // elements.
-    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
-                                          extractGranularitySize);
-    // Compute the strides of the extracted kD vector.
-    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
-    // Compute extractedStrides.
-    for (int i = kD - 2; i >= 0; --i) {
-      extractedStrides[i] =
-          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
+
+    VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+
+    VectorType outputType = insertStridedSliceOp.getType();
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    int64_t nOutputElements = outputType.getNumElements();
+
+    FailureOr<SmallVector<int64_t>> offsets =
+        intsFromArrayAttr(insertStridedSliceOp.getOffsets());
+    if (failed(offsets)) {
+      return rewriter.notifyMatchFailure(insertStridedSliceOp,
+                                         "failed to get integer offsets");
     }
-    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
-    // and compute the multi-dimensional index and the corresponding linearized
-    // index within the source vector.
-    for (int64_t i = 0; i < nExtractedSlices; ++i) {
-      int64_t index = i;
-      // Compute the corresponding multi-dimensional index.
-      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
-      for (int64_t j = 0; j < kD; ++j) {
-        multiDimIndex[j] = (index / extractedStrides[j]);
-        index -= multiDimIndex[j] * extractedStrides[j];
-      }
-      // Compute the corresponding linearized index in the source vector
-      // i.e. shift the multiDimIndex by the offsets.
-      int64_t linearizedIndex = 0;
-      for (int64_t j = 0; j < kD; ++j) {
-        linearizedIndex +=
-            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
-            sourceStrides[j];
-      }
-      // Fill the indices array form linearizedIndex to linearizedIndex +
-      // extractGranularitySize.
-      for (int64_t j = 0; j < extractGranularitySize; ++j) {
-        indices[i * extractGranularitySize + j] = linearizedIndex + j;
-      }
+    SmallVector<int64_t> sliceIndices = getFlattenedStridedSliceIndices(
+        inputShape, outputShape, offsets.value());
+
+    SmallVector<int64_t> indices(nOutputElements, 0);
+    std::iota(indices.begin(), indices.end(), 0);
+    for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
+      indices[sliceIndex] = index + nOutputElements;
     }
-    // Perform a shuffle to extract the kD vector.
-    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        extractOp, dstType, srcVector, srcVector, indices);
+
+    Value flatToStore = adaptor.getValueToStore();
+    Value flatDest = adaptor.getDest();
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp,
+                                                   flatDest.getType(), flatDest,
+                                                   flatToStore, indices);
     return success();
   }
 };
@@ -296,7 +404,7 @@ struct LinearizeVectorExtract final
     // Skip if result is not a vector type
     if (!isa<VectorType>(extractOp.getType()))
       return rewriter.notifyMatchFailure(extractOp,
-                                         "scalar extract is not supported.");
+                                         "scalar extract not supported");
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     assert(dstTy && "expected 1-D vector type");
 
@@ -453,8 +561,8 @@ struct LinearizeVectorSplat final
 static bool isNotLinearizableBecauseScalable(Operation *op) {
 
   bool unsupported =
-      isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
-          op);
+      isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp,
+          vector::ExtractOp, vector::InsertOp>(op);
   if (!unsupported)
     return false;
 
@@ -539,6 +647,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, const ConversionTarget &target,
     RewritePatternSet &patterns) {
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
-               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
-      typeConverter, patterns.getContext());
+               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
+               LinearizeVectorInsertStridedSlice>(typeConverter,
+                                                  patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 01ad1ac48b012..508bce689b14e 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -131,9 +131,9 @@ func.func @test_0d_vector() -> vector<f32> {
 
 // -----
 
-// CHECK-LABEL: test_extract_strided_slice_1
+// CHECK-LABEL: test_extract_strided_slice_2D
 // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
-func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
+func.func @test_extract_strided_slice_2D(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
 
   // CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
   // CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
@@ -147,13 +147,13 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
 
 // -----
 
-// CHECK-LABEL:   func.func @test_extract_strided_slice_1_scalable(
+// CHECK-LABEL:   func.func @test_extract_strided_slice_2D_scalable(
 // CHECK-SAME:    %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
-func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+func.func @test_extract_strided_slice_2D_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
 
   // CHECK-NOT: vector.shuffle
   // CHECK-NOT: vector.shape_cast
-  // CHECK: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
+  // CHECK: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] 
   %0 = vector.extract_strided_slice %arg0 { sizes = [2, 8], strides = [1, 1], offsets = [1, 0] } : vector<4x[8]xf32> to vector<2x[8]xf32>
 
   // CHECK: return %[[RES]] : vector<2x[8]xf32>
@@ -162,9 +162,9 @@ func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> ve
 
 // -----
 
-// CHECK-LABEL: test_extract_strided_slice_2
+// CHECK-LABEL: test_extract_strided_slice_3D
 // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
-func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
+func.func @test_extract_strided_slice_3D(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
 
   // CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
   // CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
@@ -178,6 +178,45 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
 
 // -----
 
+// Test of insert_strided_slice -> shuffle.
+// This is a contiguous insertion of 4 elements at offset 6 into a vector of 12 elements. 
+// CHECK-LABEL: insert_strided_slice_2D_into_4D
+func.func @insert_strided_slice_2D_into_4D(%arg0 : vector<2x2xi8>, %arg1 : vector<2x1x3x2xi8>) -> vector<2x1x3x2xi8> {
+
+//   CHECK-DAG:    %[[ARG0:.*]] = vector.shape_cast {{.*}}  to vector<4xi8>
+//   CHECK-DAG:    %[[ARG1:.*]] = vector.shape_cast {{.*}}  to vector<12xi8>
+//       CHECK:    vector.shuffle %[[ARG1]], %[[ARG0]]
+//  CHECK-SAME:      [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11] : vector<12xi8>, vector<4xi8>
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1, 0, 0, 0], strides = [1, 1]} : vector<2x2xi8> into vector<2x1x3x2xi8>
+
+//       CHECK:    %[[RES:.*]] = vector.shape_cast {{.*}} to vector<2x1x3x2xi8>
+//       CHECK:    return %[[RES]] : vector<2x1x3x2xi8>
+  return %0 : vector<2x1x3x2xi8>
+}
+
+// -----
+
+// Test of insert_strided_slice -> shuffle. 
+// [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15]], [[16, 17]]]
+//                                         ^         ^
+//                                         |         |
+//                          where the 2 elements are inserted into the 3x3x2 vector
+// CHECK-LABEL: insert_strided_slice_3D
+func.func @insert_strided_slice_3D(%arg0 : vector<1x2x1xi8>, %arg2 : vector<3x3x2xi8>) -> vector<3x3x2xi8> {
+
+//   CHECK-DAG:     %[[ARG0:.*]] = vector.shape_cast {{.*}}  to vector<2xi8>
+//   CHECK-DAG:     %[[ARG1:.*]] = vector.shape_cast {{.*}}  to vector<18xi8>
+//       CHECK:     vector.shuffle %[[ARG1]], %[[ARG0]]
+//  CHECK-SAME:       [0, 1, 2, 3, 4, 5, 6, 7, 8, 18, 10, 19, 12, 13, 14, 15, 16, 17] : vector<18xi8>, vector<2xi8>
+  %0 = vector.insert_strided_slice %arg0, %arg2 {offsets = [1, 1, 1], sizes = [1, 2, 1], strides = [1, 1, 1]} : vector<1x2x1xi8> into vector<3x3x2xi8>
+
+//       CHECK:     %[[RES:.*]] = vector.shape_cast {{.*}} to vector<3x3x2xi8>
+//       CHECK:     return %[[RES]] : vector<3x3x2xi8>
+  return %0 : vector<3x3x2xi8>
+}
+
+// -----
+
 // CHECK-LABEL: test_vector_shuffle
 // CHECK-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
 func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
@@ -345,3 +384,4 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
   %0 = vector.splat %arg0 : vector<4x[2]xi32>
   return %0 : vector<4x[2]xi32>
 }
+

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Some minor comments

// CHECK: %[[RES:.*]] = vector.shape_cast {{.*}} to vector<3x3x2xi8>
// CHECK: return %[[RES]] : vector<3x3x2xi8>
return %0 : vector<3x3x2xi8>
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add negative tests for strides != 1 and scalable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a negative test for scalable, but the strides !=1 case is not reachable because of the strided slice op verifiers confirm that stride is 1:

failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,

I've changed to asserts, to indicate that we expect it to be impossible currently to reach this point with strides>1 (happy to change this back to emitOpError though if that's preferred).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest that we fail the match gracefully just in case strides != 1 are supported in the future...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, I prefer that. I can't test it, but I that's unavoidable. Until strides != 1 is supported (or entirely removed :-))

SmallVector<int64_t> sliceIndices = getFlattenedStridedSliceIndices(
inputShape, outputShape, offsets.value());

SmallVector<int64_t> indices(nOutputElements, 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we shouldn't initialize indices if we are overriding them all?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, because it's not like std::vector. Forgot that, thanks!

/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
/// positions written to are (1,3) and (1,4), which have linearized indices 8
/// and 9. So [8,9] is returned.
SmallVector<int64_t> static getFlattenedStridedSliceIndices(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the name somewhat suggested to me that all the indices were returned. Could we make it more explicit? Maybe include "InsertionIndices" in the name?

//
// `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
assert(large.size() >= small.size());
assert(large.size() >= offsets.size());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add message to assert?

// CHECK-DAG: %[[ARG1:.*]] = vector.shape_cast {{.*}} to vector<18xi8>
// CHECK: vector.shuffle %[[ARG1]], %[[ARG0]]
// CHECK-SAME: [0, 1, 2, 3, 4, 5, 6, 7, 8, 18, 10, 19, 12, 13, 14, 15, 16, 17] : vector<18xi8>, vector<2xi8>
%0 = vector.insert_strided_slice %arg0, %arg2 {offsets = [1, 1, 1], sizes = [1, 2, 1], strides = [1, 1, 1]} : vector<1x2x1xi8> into vector<3x3x2xi8>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use offsets != 0 or 1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added new test

stride *= large[i];
std::swap(indices, nextIndices);
nextIndices.clear();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't have time to review this function.

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks! Please, address the remaining comments before landing

}
stride *= large[i];
std::swap(indices, nextIndices);
nextIndices.clear();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wonder if it's better to resize + clear or just move nextIndices decl inside the loop... The latter at least would make clearer that nextIndices doesn't have valid information at the beginning of the iteration... It wasn't obvious to me with the resize and clear approach...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants