From 91a9c895e27cf566558a8980b453d9e4883c3fcd Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Sun, 7 Nov 2021 01:41:52 -0400 Subject: [PATCH] Refactor Co-authored-by: "Chris Elrod" --- include/math.hpp | 144 ++++++++++++++++++++++++++------------------- src/main.cpp | 2 +- test/perm_test.cpp | 26 ++++---- 3 files changed, 96 insertions(+), 76 deletions(-) diff --git a/include/math.hpp b/include/math.hpp index f0729c3a6..7ee91399f 100644 --- a/include/math.hpp +++ b/include/math.hpp @@ -1,13 +1,17 @@ +// We'll follow Julia style, so anything that's not a constructor, destructor, +// nor an operator will be outside of the struct/class. #include #include #include #include +#include + const size_t MAX_NUM_LOOPS = 16; const size_t MAX_PROGRAM_VARIABLES = 32; typedef int32_t Int; -template struct Vector { +template struct Vector { static constexpr size_t D = !M; T* ptr; @@ -15,19 +19,21 @@ template struct Vector { Vector(T *ptr, const std::array dims) : ptr(ptr), dims(dims) {}; - size_t getSize() {return (M == 0) ? dims[0] : M;} - T &operator()(size_t i) { return ptr[i]; } +}; - void show() { - for (size_t i = 0; i < getSize(); i++) { - std::printf("%17d", (*this)(i)); - } - std::printf("\n"); +template +size_t length(Vector v) {return (M == 0) ? v.dims[0] : M;} + +template +void show(Vector v) { + for (size_t i = 0; i < length(v); i++) { + std::printf("%17d", v(i)); } -}; + std::printf("\n"); +} -template struct Matrix { +template struct Matrix { static constexpr size_t D = (!M + !N); T *ptr; @@ -35,74 +41,86 @@ template struct Matrix { Matrix(T *ptr, const std::array dims) : ptr(ptr), dims(dims){}; - size_t getSize(size_t i) { - if (i == 0) { - return (M != 0) ? M : dims[0]; - } else { - return (N != 0) ? N : dims[D - 1]; - } - } - - T &operator()(size_t i, size_t j) { return ptr[i + j * getSize(0)]; } + T &operator()(size_t i, size_t j) { return ptr[i + j * size((*this), 0)]; } +}; - void show() { - for (size_t i = 0; i < getSize(0); i++) { - for (size_t j = 0; j < getSize(1); j++) { - std::printf("%17d", (*this)(i, j)); - } - std::printf("\n"); - } +template +size_t size(Matrix A, size_t i) { + static constexpr size_t D = (!M + !N); + if (i == 0) { + return (M != 0) ? M : A.dims[0]; + } else { + return (N != 0) ? N : A.dims[D - 1]; } +} - Vector getCol(size_t i) { - if (M == 0) { - return Vector(ptr + i * getSize(0), std::array{{getSize(0)}}); - } else { - return Vector(ptr + i * getSize(0), std::array{{}}); +template +Vector getCol(Matrix A, size_t i) { + return Vector(A.ptr + i * size(A, 0), std::array{{}}); +} +template +Vector getCol(Matrix A, size_t i) { + auto s1 = size(A, 0); + return Vector(A.ptr + i * s1, std::array{{s1}}); +} + +template +void show(Matrix A) { + for (size_t i = 0; i < size(A, 0); i++) { + for (size_t j = 0; j < size(A, 1); j++) { + std::printf("%17d", A(i, j)); } + std::printf("\n"); } -}; +} template size_t getNLoops(T x) { return x.data.dims[0]; } +typedef Matrix PermutationData; +typedef Vector PermutationVector; struct Permutation { - typedef Matrix M; - M data; + PermutationData data; Permutation(Int *ptr, size_t nloops) - : data(M(ptr, std::array{{nloops}})) { + : data(PermutationData(ptr, std::array{{nloops}})) { assert(nloops <= MAX_NUM_LOOPS); }; - Int &operator()(size_t i, size_t j) { return data(i, j); } + Int &operator()(size_t i) { return data(i, 0); } +}; - Permutation init() { - auto p = (*this); - Int numloops = getNLoops(p); - for (Int n = 0; n < numloops; n++) { - p(n, 0) = n; - p(n, 1) = n; - } - return p; - } +PermutationVector inv(Permutation p) { + return getCol(p.data, 1); +} - void show() { - auto perm = (*this); - auto numloop = getNLoops(perm); - std::printf("perm: <"); - for (Int j = 0; j < numloop - 1; j++) - std::printf("%d ", perm(j, 0)); - std::printf("%d>\n", perm(numloop - 1, 0)); +Int& inv(Permutation p, size_t j) { + return p.data(j, 1); +} + +Permutation init(Permutation p) { + Int numloops = getNLoops(p); + for (Int n = 0; n < numloops; n++) { + p(n) = n; + inv(p, n) = n; } -}; + return p; +} + +void show(Permutation p) { + auto numloop = getNLoops(p); + std::printf("perm: <"); + for (Int j = 0; j < numloop - 1; j++) + std::printf("%d ", p(j)); + std::printf("%d>\n", p(numloop - 1)); +} void swap(Permutation p, Int i, Int j) { - Int xi = p(i, 0); - Int xj = p(j, 0); - p(i, 0) = xj; - p(j, 0) = xi; - p(xj, 1) = i; - p(xi, 1) = j; + Int xi = p(i); + Int xj = p(j); + p(i) = xj; + p(j) = xi; + inv(p, xj) = i; + inv(p, xi) = j; } struct PermutationSubset { @@ -304,6 +322,7 @@ function compatible(l1::TriangularLoopNest, l2::TriangularLoopNest, perm1::Permu end */ +/* bool compatible(RectangularLoopNest l1, RectangularLoopNest l2, Permutation perm1, Permutation perm2, Int i1, Int i2){ i1 = perm1(i1, 0); i2 = perm2(i2, 0); @@ -336,13 +355,13 @@ struct TriangularLoopNest { bool otherwiseIndependent(TrictM A, Int j, Int i) { for (auto k = 0; k < j; k++) if (!A(k, j)) return false; // A is symmetric - for (auto k = j+1; k < A.getSize(0); k++) + for (auto k = j+1; k < size(A, 0); k++) if (!((k == i) | (A(k, j) == 0))) return false; return true; } bool zeroMinimum(TrictM A, Int j, Int _j, Permutation perm) { - for (auto k = j+1; k < A.getSize(0); k++) { + for (auto k = j+1; k < size(A, 0); k++) { auto j_lower_bounded_by_k = A(k, j) < 0; if (!j_lower_bounded_by_k) continue; auto _k = perm(k, 1); @@ -370,13 +389,14 @@ bool zeroInnerIterationsAtMaximum(TrictM A, RektM ub, RectangularLoopNest r, Int if (Aij >= 0) continue; if (upperboundDominates(ub, i, r.data, j)) return true; } - for (auto j = i+1; j < A.getSize(0); j++) { + for (auto j = i+1; j < size(A, 0); j++) { auto Aij = A(i, j); if (Aij <= 0) continue; if (upperboundDominates(ub, i, r.data, j)) return true; } return false; } +*/ /* diff --git a/src/main.cpp b/src/main.cpp index bd46b88e4..3e4ef085f 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -17,7 +17,7 @@ int main() { } fmt::print("dims: {}, {}\n", M.dims[0], M.dims[1]); fmt::print("A[0, 1]: {}\n", M(0, 1)); - M.show(); + show(M); // delete[] f; return 0; } diff --git a/test/perm_test.cpp b/test/perm_test.cpp index 12e58a743..484cb04d0 100644 --- a/test/perm_test.cpp +++ b/test/perm_test.cpp @@ -7,7 +7,7 @@ const size_t numloop = 5; Int x[2 * numloop + 2]; -auto p = Permutation(x, numloop).init(); +auto p = init(Permutation(x, numloop)); std::set> s; std::vector tperm(numloop); @@ -15,14 +15,14 @@ void recursive_iterator(Permutation p, Int lv = 0, Int num_exterior = 0) { Int nloops = getNLoops(p); assert(lv < 6); if ((lv + 1) == nloops) { - for (Int j = 0; j < numloop; j++) tperm[j] = p(j, 0); - p.show(); + for (Int j = 0; j < numloop; j++) tperm[j] = p(j); + show(p); std::vector perm = tperm; std::sort(tperm.begin(), tperm.end()); for (Int j = 0; j < numloop; j++) { // Test if there's a bijection. - auto ip = p(j, 1); - EXPECT_EQ(p(ip, 0), j); + auto ip = inv(p, j); + EXPECT_EQ(p(ip), j); EXPECT_EQ(tperm[j], j); } // This makes the test fail DELIBERATELY. We didn't know @@ -49,14 +49,14 @@ void recursive_iterator_2(PermutationLevelIterator pli, Int lv = 0, Int num_exte Int nloops = getNLoops(p); assert(lv < 6); if ((lv + 1) == nloops) { - for (Int j = 0; j < numloop; j++) tperm[j] = p(j, 0); - p.show(); + for (Int j = 0; j < numloop; j++) tperm[j] = p(j); + show(p); std::vector perm = tperm; std::sort(tperm.begin(), tperm.end()); for (Int j = 0; j < numloop; j++) { // Test if there's a bijection. - auto ip = p(j, 1); - EXPECT_EQ(p(ip, 0), j); + auto ip = inv(p, j); + EXPECT_EQ(p(ip), j); EXPECT_EQ(tperm[j], j); } // This makes the test fail DELIBERATELY. We didn't know @@ -79,7 +79,7 @@ void recursive_iterator_2(PermutationLevelIterator pli, Int lv = 0, Int num_exte TEST(PermTest, BasicAssertions) { s.clear(); - p.init(); + init(p); recursive_iterator(p); // Test the number of permutations == numloop! @@ -87,19 +87,19 @@ TEST(PermTest, BasicAssertions) { std::printf("[Nice 1] Phew, we are done with PermTest!\n"); s.clear(); - p.init(); + init(p); recursive_iterator_2(PermutationLevelIterator(p, 0, 0)); EXPECT_EQ(s.size(), 5 * 4 * 3 * 2 * 1); std::printf("[Nice 2] Phew, we are done with PermTest!\n"); s.clear(); - p.init(); + init(p); recursive_iterator(p, 0, 3); EXPECT_EQ(s.size(), 3 * 2 * 1 * (2 * 1)); std::printf("[Nice 3] Phew, we are done with PermTest!\n"); s.clear(); - p.init(); + init(p); recursive_iterator_2(PermutationLevelIterator(p, 0, 3)); // Test the number of permutations == numloop! EXPECT_EQ(s.size(), 3 * 2 * 1 * (2 * 1));