40
40
41
41
using af::af_cdouble;
42
42
using af::af_cfloat;
43
+ using std::vector;
43
44
44
45
bool operator ==(const af_half &lhs, const af_half &rhs) {
45
46
return lhs.data_ == rhs.data_ ;
@@ -1390,6 +1391,116 @@ INSTANTIATE(long long);
1390
1391
INSTANTIATE (unsigned long long );
1391
1392
#undef INSTANTIATE
1392
1393
1394
+ template <typename T>
1395
+ struct sparseCooValue {
1396
+ int row = 0 ;
1397
+ int col = 0 ;
1398
+ T value = 0 ;
1399
+ sparseCooValue (int r, int c, T v) : row(r), col(c), value(v) {}
1400
+ };
1401
+
1402
+ template <typename T>
1403
+ void swap (sparseCooValue<T> &lhs, sparseCooValue<T> &rhs) {
1404
+ std::swap (lhs.row , rhs.row );
1405
+ std::swap (lhs.col , rhs.col );
1406
+ std::swap (lhs.value , rhs.value );
1407
+ }
1408
+
1409
+ template <typename T>
1410
+ bool operator <(const sparseCooValue<T> &lhs, const sparseCooValue<T> &rhs) {
1411
+ if (lhs.row < rhs.row ) {
1412
+ return true ;
1413
+ } else if (lhs.row == rhs.row && lhs.col < rhs.col ) {
1414
+ return true ;
1415
+ } else {
1416
+ return false ;
1417
+ }
1418
+ }
1419
+
1420
+ template <typename T>
1421
+ std::ostream &operator <<(std::ostream &os, const sparseCooValue<T> &val) {
1422
+ os << " (" << val.row << " , " << val.col << " ): " << val.value ;
1423
+ return os;
1424
+ }
1425
+
1426
+ template <typename T>
1427
+ bool isZero (const sparseCooValue<T> &val) {
1428
+ return val.value == 0 .;
1429
+ }
1430
+
1431
+ template <typename T>
1432
+ vector<sparseCooValue<T>> toCooVector (const af::array &arr) {
1433
+ vector<sparseCooValue<T>> out;
1434
+ if (arr.issparse ()) {
1435
+ switch (sparseGetStorage (arr)) {
1436
+ case AF_STORAGE_COO: {
1437
+ dim_t nnz = sparseGetNNZ (arr);
1438
+ vector<int > row (nnz), col (nnz);
1439
+ vector<T> values (nnz);
1440
+ sparseGetValues (arr).host (values.data ());
1441
+ sparseGetRowIdx (arr).host (row.data ());
1442
+ sparseGetColIdx (arr).host (col.data ());
1443
+ out.reserve (nnz);
1444
+ for (int i = 0 ; i < nnz; i++) {
1445
+ out.emplace_back (row[i], col[i], values[i]);
1446
+ }
1447
+ } break ;
1448
+ case AF_STORAGE_CSR: {
1449
+ dim_t nnz = sparseGetNNZ (arr);
1450
+ vector<int > row (arr.dims (0 ) + 1 ), col (nnz);
1451
+ vector<T> values (nnz);
1452
+ sparseGetValues (arr).host (values.data ());
1453
+ sparseGetRowIdx (arr).host (row.data ());
1454
+ sparseGetColIdx (arr).host (col.data ());
1455
+ out.reserve (nnz);
1456
+ for (int i = 0 ; i < row.size () - 1 ; i++) {
1457
+ for (int r = row[i]; r < row[i + 1 ]; r++) {
1458
+ out.emplace_back (i, col[r], values[r]);
1459
+ }
1460
+ }
1461
+ } break ;
1462
+ case AF_STORAGE_CSC: {
1463
+ dim_t nnz = sparseGetNNZ (arr);
1464
+ vector<int > row (nnz), col (arr.dims (1 ) + 1 );
1465
+ vector<T> values (nnz);
1466
+ sparseGetValues (arr).host (values.data ());
1467
+ sparseGetRowIdx (arr).host (row.data ());
1468
+ sparseGetColIdx (arr).host (col.data ());
1469
+ out.reserve (nnz);
1470
+ for (int i = 0 ; i < col.size () - 1 ; i++) {
1471
+ for (int c = col[i]; c < col[i + 1 ]; c++) {
1472
+ out.emplace_back (row[c], i, values[c]);
1473
+ }
1474
+ }
1475
+ } break ;
1476
+ default : throw std::logic_error (" NOT SUPPORTED" );
1477
+ }
1478
+ } else {
1479
+ vector<T> values (arr.elements ());
1480
+ arr.host (values.data ());
1481
+ int M = arr.dims (0 ), N = arr.dims (1 );
1482
+ for (int j = 0 ; j < N; j++) {
1483
+ for (int i = 0 ; i < M; i++) {
1484
+ if (std::fpclassify (real (values[j * M + i])) == FP_ZERO) {
1485
+ out.emplace_back (i, j, values[j * M + i]);
1486
+ }
1487
+ }
1488
+ }
1489
+ }
1490
+
1491
+ // Remove zero elements from result to ensure that only non-zero elements
1492
+ // are compared
1493
+ out.erase (std::remove_if (out.begin (), out.end (), isZero<T>), out.end ());
1494
+ std::sort (begin (out), end (out));
1495
+ return out;
1496
+ }
1497
+
1498
+ template <typename T>
1499
+ bool operator ==(const sparseCooValue<T> &lhs, sparseCooValue<T> &rhs) {
1500
+ return lhs.row == rhs.row && lhs.col == rhs.col &&
1501
+ cmp (lhs.value , rhs.value );
1502
+ }
1503
+
1393
1504
template <typename T>
1394
1505
std::string printContext (const std::vector<T> &hGold, std::string goldName,
1395
1506
const std::vector<T> &hOut, std::string outName,
@@ -1495,13 +1606,100 @@ std::string printContext(const std::vector<T> &hGold, std::string goldName,
1495
1606
return os.str ();
1496
1607
}
1497
1608
1609
+ template <typename T>
1610
+ std::string printContext (const std::vector<sparseCooValue<T>> &hGold,
1611
+ std::string goldName,
1612
+ const std::vector<sparseCooValue<T>> &hOut,
1613
+ std::string outName, af::dim4 arrDims,
1614
+ af::dim4 arrStrides, dim_t idx) {
1615
+ std::ostringstream os;
1616
+
1617
+ af::dim4 coords = unravelIdx (idx, arrDims, arrStrides);
1618
+ dim_t ctxWidth = 5 ;
1619
+
1620
+ // Coordinates that span dim0
1621
+ af::dim4 coordsMinBound = coords;
1622
+ coordsMinBound[0 ] = 0 ;
1623
+ af::dim4 coordsMaxBound = coords;
1624
+ coordsMaxBound[0 ] = arrDims[0 ] - 1 ;
1625
+
1626
+ // dim0 positions that can be displayed
1627
+ dim_t dim0Start = std::max<dim_t >(0LL , idx - ctxWidth);
1628
+ dim_t dim0End = std::min<dim_t >(idx + ctxWidth + 1LL , hGold.size ());
1629
+
1630
+ int setwval = 9 ;
1631
+ // Linearized indices of values in vectors that can be displayed
1632
+ dim_t vecStartIdx =
1633
+ std::max<dim_t >(ravelIdx (coordsMinBound, arrStrides), idx - ctxWidth);
1634
+ os << " Idx: " ;
1635
+ for (int elem = dim0Start; elem < dim0End; elem++) {
1636
+ if (elem == idx) {
1637
+ os << std::setw (setwval - 2 ) << " [" << elem << " ]" ;
1638
+ } else {
1639
+ os << std::setw (setwval) << elem;
1640
+ }
1641
+ }
1642
+ os << " \n Row: " ;
1643
+ for (int elem = dim0Start; elem < dim0End; elem++) {
1644
+ if (elem == idx) {
1645
+ os << std::setw (setwval - 2 ) << " [" << hGold[elem].row << " ]" ;
1646
+ } else {
1647
+ os << std::setw (setwval) << hGold[elem].row ;
1648
+ }
1649
+ }
1650
+ os << " \n " ;
1651
+ for (int elem = dim0Start; elem < dim0End; elem++) {
1652
+ if (elem == idx) {
1653
+ os << std::setw (setwval - 2 ) << " [" << hOut[elem].row << " ]" ;
1654
+ } else {
1655
+ os << std::setw (setwval) << hOut[elem].row ;
1656
+ }
1657
+ }
1658
+ os << " \n Col: " ;
1659
+ for (int elem = dim0Start; elem < dim0End; elem++) {
1660
+ if (elem == idx) {
1661
+ os << std::setw (setwval - 2 ) << " [" << hGold[elem].col << " ]" ;
1662
+ } else {
1663
+ os << std::setw (setwval) << hGold[elem].col ;
1664
+ }
1665
+ }
1666
+ os << " \n " ;
1667
+ for (int elem = dim0Start; elem < dim0End; elem++) {
1668
+ if (elem == idx) {
1669
+ os << std::setw (setwval - 2 ) << " [" << hOut[elem].col << " ]" ;
1670
+ } else {
1671
+ os << std::setw (setwval) << hOut[elem].col ;
1672
+ }
1673
+ }
1674
+
1675
+ os << " \n Value: " ;
1676
+ for (int elem = dim0Start; elem < dim0End; elem++) {
1677
+ if (elem == idx) {
1678
+ os << std::setw (setwval - 2 ) << " [" << hGold[elem].value << " ]" ;
1679
+ } else {
1680
+ os << std::setw (setwval) << hGold[elem].value ;
1681
+ }
1682
+ }
1683
+ os << " \n " ;
1684
+ for (int elem = dim0Start; elem < dim0End; elem++) {
1685
+ if (elem == idx) {
1686
+ os << std::setw (setwval - 2 ) << " [" << hOut[elem].value << " ]" ;
1687
+ } else {
1688
+ os << std::setw (setwval) << hOut[elem].value ;
1689
+ }
1690
+ }
1691
+
1692
+ return os.str ();
1693
+ }
1694
+
1498
1695
template <typename T>
1499
1696
::testing::AssertionResult elemWiseEq (std::string aName, std::string bName,
1500
1697
const std::vector<T> &a, af::dim4 aDims,
1501
1698
const std::vector<T> &b, af::dim4 bDims,
1502
1699
float maxAbsDiff, IntegerTag) {
1503
1700
UNUSED (maxAbsDiff);
1504
1701
typedef typename std::vector<T>::const_iterator iter;
1702
+
1505
1703
std::pair<iter, iter> mismatches =
1506
1704
std::mismatch (a.begin (), a.end (), b.begin ());
1507
1705
iter bItr = mismatches.second ;
@@ -1525,7 +1723,7 @@ struct absMatch {
1525
1723
absMatch (float diff) : diff_(diff) {}
1526
1724
1527
1725
template <typename T>
1528
- bool operator ()(T lhs, T rhs) {
1726
+ bool operator ()(const T & lhs, const T & rhs) const {
1529
1727
if (diff_ > 0 ) {
1530
1728
using half_float::abs;
1531
1729
using std::abs;
@@ -1537,25 +1735,26 @@ struct absMatch {
1537
1735
};
1538
1736
1539
1737
template <>
1540
- bool absMatch::operator ()<af::af_cfloat>(af::af_cfloat lhs, af::af_cfloat rhs) {
1738
+ bool absMatch::operator ()<af::af_cfloat>(const af::af_cfloat &lhs,
1739
+ const af::af_cfloat &rhs) const {
1541
1740
return af::abs (rhs - lhs) <= diff_;
1542
1741
}
1543
1742
1544
1743
template <>
1545
- bool absMatch::operator ()<af::af_cdouble>(af::af_cdouble lhs,
1546
- af::af_cdouble rhs) {
1744
+ bool absMatch::operator ()<af::af_cdouble>(const af::af_cdouble & lhs,
1745
+ const af::af_cdouble & rhs) const {
1547
1746
return af::abs (rhs - lhs) <= diff_;
1548
1747
}
1549
1748
1550
1749
template <>
1551
- bool absMatch::operator ()<std::complex<float>>(std:: complex < float > lhs,
1552
- std::complex <float > rhs) {
1750
+ bool absMatch::operator ()<std::complex<float>>(
1751
+ const std:: complex < float > &lhs, const std::complex <float > & rhs) const {
1553
1752
return std::abs (rhs - lhs) <= diff_;
1554
1753
}
1555
1754
1556
1755
template <>
1557
- bool absMatch::operator ()<std::complex<double>>(std:: complex < double > lhs,
1558
- std::complex <double > rhs) {
1756
+ bool absMatch::operator ()<std::complex<double>>(
1757
+ const std:: complex < double > &lhs, const std::complex <double > & rhs) const {
1559
1758
return std::abs (rhs - lhs) <= diff_;
1560
1759
}
1561
1760
@@ -1597,6 +1796,53 @@ ::testing::AssertionResult elemWiseEq(std::string aName, std::string bName,
1597
1796
}
1598
1797
}
1599
1798
1799
+ template <typename T>
1800
+ ::testing::AssertionResult elemWiseEq (std::string aName, std::string bName,
1801
+ const std::vector<sparseCooValue<T>> &a,
1802
+ af::dim4 aDims,
1803
+ const std::vector<sparseCooValue<T>> &b,
1804
+ af::dim4 bDims, float maxAbsDiff,
1805
+ IntegerTag) {
1806
+ return ::testing::AssertionFailure () << " Unsupported sparse type\n " ;
1807
+ }
1808
+ template <typename T>
1809
+ ::testing::AssertionResult elemWiseEq (std::string aName, std::string bName,
1810
+ const std::vector<sparseCooValue<T>> &a,
1811
+ af::dim4 aDims,
1812
+ const std::vector<sparseCooValue<T>> &b,
1813
+ af::dim4 bDims, float maxAbsDiff,
1814
+ FloatTag) {
1815
+ typedef typename std::vector<sparseCooValue<T>>::const_iterator iter;
1816
+ // TODO(mark): Modify equality for float
1817
+
1818
+ const absMatch diff (maxAbsDiff);
1819
+ std::pair<iter, iter> mismatches = std::mismatch (
1820
+ a.begin (), a.end (), b.begin (),
1821
+ [&diff](const sparseCooValue<T> &lhs, const sparseCooValue<T> &rhs) {
1822
+ return lhs.row == rhs.row && lhs.col == rhs.col &&
1823
+ diff (lhs.value , rhs.value );
1824
+ });
1825
+
1826
+ iter aItr = mismatches.first ;
1827
+ iter bItr = mismatches.second ;
1828
+
1829
+ if (aItr == a.end ()) {
1830
+ return ::testing::AssertionSuccess ();
1831
+ } else {
1832
+ dim_t idx = std::distance (b.begin (), bItr);
1833
+ af::dim4 coords = unravelIdx (idx, bDims, calcStrides (bDims));
1834
+
1835
+ af::dim4 aStrides = calcStrides (aDims);
1836
+
1837
+ ::testing::AssertionResult result =
1838
+ ::testing::AssertionFailure ()
1839
+ << "VALUE DIFFERS at " << idx << ":\n"
1840
+ << printContext(a, aName, b, bName, aDims, aStrides, idx);
1841
+
1842
+ return result;
1843
+ }
1844
+ }
1845
+
1600
1846
template <typename T>
1601
1847
::testing::AssertionResult elemWiseEq (std::string aName, std::string bName,
1602
1848
const af::array &a, const af::array &b,
@@ -1606,13 +1852,21 @@ ::testing::AssertionResult elemWiseEq(std::string aName, std::string bName,
1606
1852
FloatTag, IntegerTag>::type TagType;
1607
1853
TagType tag;
1608
1854
1609
- std::vector<T> hA (static_cast <size_t >(a.elements ()));
1610
- a.host (hA.data ());
1855
+ if (a.issparse () || b.issparse ()) {
1856
+ vector<sparseCooValue<T>> hA = toCooVector<T>(a);
1857
+ vector<sparseCooValue<T>> hB = toCooVector<T>(b);
1611
1858
1612
- std::vector<T> hB (static_cast <size_t >(b.elements ()));
1613
- b.host (hB.data ());
1614
- return elemWiseEq<T>(aName, bName, hA, a.dims (), hB, b.dims (), maxAbsDiff,
1615
- tag);
1859
+ return elemWiseEq<T>(aName, bName, hA, a.dims (), hB, b.dims (),
1860
+ maxAbsDiff, tag);
1861
+ } else {
1862
+ std::vector<T> hA (static_cast <size_t >(a.elements ()));
1863
+ a.host (hA.data ());
1864
+
1865
+ std::vector<T> hB (static_cast <size_t >(b.elements ()));
1866
+ b.host (hB.data ());
1867
+ return elemWiseEq<T>(aName, bName, hA, a.dims (), hB, b.dims (),
1868
+ maxAbsDiff, tag);
1869
+ }
1616
1870
}
1617
1871
1618
1872
template <typename T>
0 commit comments