Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 66ca6e9

Browse files
committed
Update AF_ASSERT_ARRAYS_[EQ,NEAR] to accept sparse arrays
AF_ASSERT_ARRAY_* now accept sparse arrays and can be compared to dense arrays now
1 parent 529e98b commit 66ca6e9

File tree

4 files changed

+284
-98
lines changed

4 files changed

+284
-98
lines changed

test/arrayfire_test.cpp

Lines changed: 268 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
using af::af_cdouble;
4242
using af::af_cfloat;
43+
using std::vector;
4344

4445
bool operator==(const af_half &lhs, const af_half &rhs) {
4546
return lhs.data_ == rhs.data_;
@@ -1390,6 +1391,116 @@ INSTANTIATE(long long);
13901391
INSTANTIATE(unsigned long long);
13911392
#undef INSTANTIATE
13921393

1394+
template<typename T>
1395+
struct sparseCooValue {
1396+
int row = 0;
1397+
int col = 0;
1398+
T value = 0;
1399+
sparseCooValue(int r, int c, T v) : row(r), col(c), value(v) {}
1400+
};
1401+
1402+
template<typename T>
1403+
void swap(sparseCooValue<T> &lhs, sparseCooValue<T> &rhs) {
1404+
std::swap(lhs.row, rhs.row);
1405+
std::swap(lhs.col, rhs.col);
1406+
std::swap(lhs.value, rhs.value);
1407+
}
1408+
1409+
template<typename T>
1410+
bool operator<(const sparseCooValue<T> &lhs, const sparseCooValue<T> &rhs) {
1411+
if (lhs.row < rhs.row) {
1412+
return true;
1413+
} else if (lhs.row == rhs.row && lhs.col < rhs.col) {
1414+
return true;
1415+
} else {
1416+
return false;
1417+
}
1418+
}
1419+
1420+
template<typename T>
1421+
std::ostream &operator<<(std::ostream &os, const sparseCooValue<T> &val) {
1422+
os << "(" << val.row << ", " << val.col << "): " << val.value;
1423+
return os;
1424+
}
1425+
1426+
template<typename T>
1427+
bool isZero(const sparseCooValue<T> &val) {
1428+
return val.value == 0.;
1429+
}
1430+
1431+
template<typename T>
1432+
vector<sparseCooValue<T>> toCooVector(const af::array &arr) {
1433+
vector<sparseCooValue<T>> out;
1434+
if (arr.issparse()) {
1435+
switch (sparseGetStorage(arr)) {
1436+
case AF_STORAGE_COO: {
1437+
dim_t nnz = sparseGetNNZ(arr);
1438+
vector<int> row(nnz), col(nnz);
1439+
vector<T> values(nnz);
1440+
sparseGetValues(arr).host(values.data());
1441+
sparseGetRowIdx(arr).host(row.data());
1442+
sparseGetColIdx(arr).host(col.data());
1443+
out.reserve(nnz);
1444+
for (int i = 0; i < nnz; i++) {
1445+
out.emplace_back(row[i], col[i], values[i]);
1446+
}
1447+
} break;
1448+
case AF_STORAGE_CSR: {
1449+
dim_t nnz = sparseGetNNZ(arr);
1450+
vector<int> row(arr.dims(0) + 1), col(nnz);
1451+
vector<T> values(nnz);
1452+
sparseGetValues(arr).host(values.data());
1453+
sparseGetRowIdx(arr).host(row.data());
1454+
sparseGetColIdx(arr).host(col.data());
1455+
out.reserve(nnz);
1456+
for (int i = 0; i < row.size() - 1; i++) {
1457+
for (int r = row[i]; r < row[i + 1]; r++) {
1458+
out.emplace_back(i, col[r], values[r]);
1459+
}
1460+
}
1461+
} break;
1462+
case AF_STORAGE_CSC: {
1463+
dim_t nnz = sparseGetNNZ(arr);
1464+
vector<int> row(nnz), col(arr.dims(1) + 1);
1465+
vector<T> values(nnz);
1466+
sparseGetValues(arr).host(values.data());
1467+
sparseGetRowIdx(arr).host(row.data());
1468+
sparseGetColIdx(arr).host(col.data());
1469+
out.reserve(nnz);
1470+
for (int i = 0; i < col.size() - 1; i++) {
1471+
for (int c = col[i]; c < col[i + 1]; c++) {
1472+
out.emplace_back(row[c], i, values[c]);
1473+
}
1474+
}
1475+
} break;
1476+
default: throw std::logic_error("NOT SUPPORTED");
1477+
}
1478+
} else {
1479+
vector<T> values(arr.elements());
1480+
arr.host(values.data());
1481+
int M = arr.dims(0), N = arr.dims(1);
1482+
for (int j = 0; j < N; j++) {
1483+
for (int i = 0; i < M; i++) {
1484+
if (std::fpclassify(real(values[j * M + i])) == FP_ZERO) {
1485+
out.emplace_back(i, j, values[j * M + i]);
1486+
}
1487+
}
1488+
}
1489+
}
1490+
1491+
// Remove zero elements from result to ensure that only non-zero elements
1492+
// are compared
1493+
out.erase(std::remove_if(out.begin(), out.end(), isZero<T>), out.end());
1494+
std::sort(begin(out), end(out));
1495+
return out;
1496+
}
1497+
1498+
template<typename T>
1499+
bool operator==(const sparseCooValue<T> &lhs, sparseCooValue<T> &rhs) {
1500+
return lhs.row == rhs.row && lhs.col == rhs.col &&
1501+
cmp(lhs.value, rhs.value);
1502+
}
1503+
13931504
template<typename T>
13941505
std::string printContext(const std::vector<T> &hGold, std::string goldName,
13951506
const std::vector<T> &hOut, std::string outName,
@@ -1495,13 +1606,100 @@ std::string printContext(const std::vector<T> &hGold, std::string goldName,
14951606
return os.str();
14961607
}
14971608

1609+
template<typename T>
1610+
std::string printContext(const std::vector<sparseCooValue<T>> &hGold,
1611+
std::string goldName,
1612+
const std::vector<sparseCooValue<T>> &hOut,
1613+
std::string outName, af::dim4 arrDims,
1614+
af::dim4 arrStrides, dim_t idx) {
1615+
std::ostringstream os;
1616+
1617+
af::dim4 coords = unravelIdx(idx, arrDims, arrStrides);
1618+
dim_t ctxWidth = 5;
1619+
1620+
// Coordinates that span dim0
1621+
af::dim4 coordsMinBound = coords;
1622+
coordsMinBound[0] = 0;
1623+
af::dim4 coordsMaxBound = coords;
1624+
coordsMaxBound[0] = arrDims[0] - 1;
1625+
1626+
// dim0 positions that can be displayed
1627+
dim_t dim0Start = std::max<dim_t>(0LL, idx - ctxWidth);
1628+
dim_t dim0End = std::min<dim_t>(idx + ctxWidth + 1LL, hGold.size());
1629+
1630+
int setwval = 9;
1631+
// Linearized indices of values in vectors that can be displayed
1632+
dim_t vecStartIdx =
1633+
std::max<dim_t>(ravelIdx(coordsMinBound, arrStrides), idx - ctxWidth);
1634+
os << "Idx: ";
1635+
for (int elem = dim0Start; elem < dim0End; elem++) {
1636+
if (elem == idx) {
1637+
os << std::setw(setwval - 2) << "[" << elem << "]";
1638+
} else {
1639+
os << std::setw(setwval) << elem;
1640+
}
1641+
}
1642+
os << "\nRow: ";
1643+
for (int elem = dim0Start; elem < dim0End; elem++) {
1644+
if (elem == idx) {
1645+
os << std::setw(setwval - 2) << "[" << hGold[elem].row << "]";
1646+
} else {
1647+
os << std::setw(setwval) << hGold[elem].row;
1648+
}
1649+
}
1650+
os << "\n ";
1651+
for (int elem = dim0Start; elem < dim0End; elem++) {
1652+
if (elem == idx) {
1653+
os << std::setw(setwval - 2) << "[" << hOut[elem].row << "]";
1654+
} else {
1655+
os << std::setw(setwval) << hOut[elem].row;
1656+
}
1657+
}
1658+
os << "\nCol: ";
1659+
for (int elem = dim0Start; elem < dim0End; elem++) {
1660+
if (elem == idx) {
1661+
os << std::setw(setwval - 2) << "[" << hGold[elem].col << "]";
1662+
} else {
1663+
os << std::setw(setwval) << hGold[elem].col;
1664+
}
1665+
}
1666+
os << "\n ";
1667+
for (int elem = dim0Start; elem < dim0End; elem++) {
1668+
if (elem == idx) {
1669+
os << std::setw(setwval - 2) << "[" << hOut[elem].col << "]";
1670+
} else {
1671+
os << std::setw(setwval) << hOut[elem].col;
1672+
}
1673+
}
1674+
1675+
os << "\nValue: ";
1676+
for (int elem = dim0Start; elem < dim0End; elem++) {
1677+
if (elem == idx) {
1678+
os << std::setw(setwval - 2) << "[" << hGold[elem].value << "]";
1679+
} else {
1680+
os << std::setw(setwval) << hGold[elem].value;
1681+
}
1682+
}
1683+
os << "\n ";
1684+
for (int elem = dim0Start; elem < dim0End; elem++) {
1685+
if (elem == idx) {
1686+
os << std::setw(setwval - 2) << "[" << hOut[elem].value << "]";
1687+
} else {
1688+
os << std::setw(setwval) << hOut[elem].value;
1689+
}
1690+
}
1691+
1692+
return os.str();
1693+
}
1694+
14981695
template<typename T>
14991696
::testing::AssertionResult elemWiseEq(std::string aName, std::string bName,
15001697
const std::vector<T> &a, af::dim4 aDims,
15011698
const std::vector<T> &b, af::dim4 bDims,
15021699
float maxAbsDiff, IntegerTag) {
15031700
UNUSED(maxAbsDiff);
15041701
typedef typename std::vector<T>::const_iterator iter;
1702+
15051703
std::pair<iter, iter> mismatches =
15061704
std::mismatch(a.begin(), a.end(), b.begin());
15071705
iter bItr = mismatches.second;
@@ -1525,7 +1723,7 @@ struct absMatch {
15251723
absMatch(float diff) : diff_(diff) {}
15261724

15271725
template<typename T>
1528-
bool operator()(T lhs, T rhs) {
1726+
bool operator()(const T &lhs, const T &rhs) const {
15291727
if (diff_ > 0) {
15301728
using half_float::abs;
15311729
using std::abs;
@@ -1537,25 +1735,26 @@ struct absMatch {
15371735
};
15381736

15391737
template<>
1540-
bool absMatch::operator()<af::af_cfloat>(af::af_cfloat lhs, af::af_cfloat rhs) {
1738+
bool absMatch::operator()<af::af_cfloat>(const af::af_cfloat &lhs,
1739+
const af::af_cfloat &rhs) const {
15411740
return af::abs(rhs - lhs) <= diff_;
15421741
}
15431742

15441743
template<>
1545-
bool absMatch::operator()<af::af_cdouble>(af::af_cdouble lhs,
1546-
af::af_cdouble rhs) {
1744+
bool absMatch::operator()<af::af_cdouble>(const af::af_cdouble &lhs,
1745+
const af::af_cdouble &rhs) const {
15471746
return af::abs(rhs - lhs) <= diff_;
15481747
}
15491748

15501749
template<>
1551-
bool absMatch::operator()<std::complex<float>>(std::complex<float> lhs,
1552-
std::complex<float> rhs) {
1750+
bool absMatch::operator()<std::complex<float>>(
1751+
const std::complex<float> &lhs, const std::complex<float> &rhs) const {
15531752
return std::abs(rhs - lhs) <= diff_;
15541753
}
15551754

15561755
template<>
1557-
bool absMatch::operator()<std::complex<double>>(std::complex<double> lhs,
1558-
std::complex<double> rhs) {
1756+
bool absMatch::operator()<std::complex<double>>(
1757+
const std::complex<double> &lhs, const std::complex<double> &rhs) const {
15591758
return std::abs(rhs - lhs) <= diff_;
15601759
}
15611760

@@ -1597,6 +1796,53 @@ ::testing::AssertionResult elemWiseEq(std::string aName, std::string bName,
15971796
}
15981797
}
15991798

1799+
template<typename T>
1800+
::testing::AssertionResult elemWiseEq(std::string aName, std::string bName,
1801+
const std::vector<sparseCooValue<T>> &a,
1802+
af::dim4 aDims,
1803+
const std::vector<sparseCooValue<T>> &b,
1804+
af::dim4 bDims, float maxAbsDiff,
1805+
IntegerTag) {
1806+
return ::testing::AssertionFailure() << "Unsupported sparse type\n";
1807+
}
1808+
template<typename T>
1809+
::testing::AssertionResult elemWiseEq(std::string aName, std::string bName,
1810+
const std::vector<sparseCooValue<T>> &a,
1811+
af::dim4 aDims,
1812+
const std::vector<sparseCooValue<T>> &b,
1813+
af::dim4 bDims, float maxAbsDiff,
1814+
FloatTag) {
1815+
typedef typename std::vector<sparseCooValue<T>>::const_iterator iter;
1816+
// TODO(mark): Modify equality for float
1817+
1818+
const absMatch diff(maxAbsDiff);
1819+
std::pair<iter, iter> mismatches = std::mismatch(
1820+
a.begin(), a.end(), b.begin(),
1821+
[&diff](const sparseCooValue<T> &lhs, const sparseCooValue<T> &rhs) {
1822+
return lhs.row == rhs.row && lhs.col == rhs.col &&
1823+
diff(lhs.value, rhs.value);
1824+
});
1825+
1826+
iter aItr = mismatches.first;
1827+
iter bItr = mismatches.second;
1828+
1829+
if (aItr == a.end()) {
1830+
return ::testing::AssertionSuccess();
1831+
} else {
1832+
dim_t idx = std::distance(b.begin(), bItr);
1833+
af::dim4 coords = unravelIdx(idx, bDims, calcStrides(bDims));
1834+
1835+
af::dim4 aStrides = calcStrides(aDims);
1836+
1837+
::testing::AssertionResult result =
1838+
::testing::AssertionFailure()
1839+
<< "VALUE DIFFERS at " << idx << ":\n"
1840+
<< printContext(a, aName, b, bName, aDims, aStrides, idx);
1841+
1842+
return result;
1843+
}
1844+
}
1845+
16001846
template<typename T>
16011847
::testing::AssertionResult elemWiseEq(std::string aName, std::string bName,
16021848
const af::array &a, const af::array &b,
@@ -1606,13 +1852,21 @@ ::testing::AssertionResult elemWiseEq(std::string aName, std::string bName,
16061852
FloatTag, IntegerTag>::type TagType;
16071853
TagType tag;
16081854

1609-
std::vector<T> hA(static_cast<size_t>(a.elements()));
1610-
a.host(hA.data());
1855+
if (a.issparse() || b.issparse()) {
1856+
vector<sparseCooValue<T>> hA = toCooVector<T>(a);
1857+
vector<sparseCooValue<T>> hB = toCooVector<T>(b);
16111858

1612-
std::vector<T> hB(static_cast<size_t>(b.elements()));
1613-
b.host(hB.data());
1614-
return elemWiseEq<T>(aName, bName, hA, a.dims(), hB, b.dims(), maxAbsDiff,
1615-
tag);
1859+
return elemWiseEq<T>(aName, bName, hA, a.dims(), hB, b.dims(),
1860+
maxAbsDiff, tag);
1861+
} else {
1862+
std::vector<T> hA(static_cast<size_t>(a.elements()));
1863+
a.host(hA.data());
1864+
1865+
std::vector<T> hB(static_cast<size_t>(b.elements()));
1866+
b.host(hB.data());
1867+
return elemWiseEq<T>(aName, bName, hA, a.dims(), hB, b.dims(),
1868+
maxAbsDiff, tag);
1869+
}
16161870
}
16171871

16181872
template<typename T>

0 commit comments

Comments
 (0)