diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h index 0dd23ac52ac67..0407bccaae879 100644 --- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h @@ -214,8 +214,10 @@ class SparseTensorStorageBase { /// * `added` a map from `[0..count)` to last-level coordinates for /// which `filled` is true and `values` contains the assotiated value. /// * `count` the size of `added`. + /// * `expsz` the size of the expanded vector (verification only). #define DECL_EXPINSERT(VNAME, V) \ - virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t); + virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t, \ + uint64_t); MLIR_SPARSETENSOR_FOREVERY_V(DECL_EXPINSERT) #undef DECL_EXPINSERT @@ -426,7 +428,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase { /// Partially specialize expanded insertions based on template types. void expInsert(uint64_t *lvlCoords, V *values, bool *filled, uint64_t *added, - uint64_t count) final { + uint64_t count, uint64_t expsz) final { assert((lvlCoords && values && filled && added) && "Received nullptr"); if (count == 0) return; @@ -435,6 +437,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase { // Restore insertion path for first insert. const uint64_t lastLvl = getLvlRank() - 1; uint64_t c = added[0]; + assert(c <= expsz); assert(filled[c] && "added coordinate is not filled"); lvlCoords[lastLvl] = c; lexInsert(lvlCoords, values[c]); @@ -444,6 +447,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase { for (uint64_t i = 1; i < count; ++i) { assert(c < added[i] && "non-lexicographic insertion"); c = added[i]; + assert(c <= expsz); assert(filled[c] && "added coordinate is not filled"); lvlCoords[lastLvl] = c; insPath(lvlCoords, lastLvl, added[i - 1] + 1, values[c]); diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp index 77861e074e933..199e4205a61a2 100644 --- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp @@ -90,7 +90,7 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT) #define IMPL_EXPINSERT(VNAME, V) \ void SparseTensorStorageBase::expInsert(uint64_t *, V *, bool *, uint64_t *, \ - uint64_t) { \ + uint64_t, uint64_t) { \ FATAL_PIV("expInsert" #VNAME); \ } MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT) diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp index 05da8cd79190e..8340fe7dcf925 100644 --- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp @@ -480,7 +480,8 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT) V *values = MEMREF_GET_PAYLOAD(vref); \ bool *filled = MEMREF_GET_PAYLOAD(fref); \ index_type *added = MEMREF_GET_PAYLOAD(aref); \ - tensor.expInsert(lvlCoords, values, filled, added, count); \ + uint64_t expsz = vref->sizes[0]; \ + tensor.expInsert(lvlCoords, values, filled, added, count, expsz); \ } MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT) #undef IMPL_EXPINSERT