Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 61d7a10

Browse files
committed
ENH: improve PyArray_Nonzero for sparse bool masks
We already count the number of true elements, so we can detect when the mask is sparse and use the faster (4-byte) vectorized npy_memchr.
1 parent 4da29a8 commit 61d7a10

File tree

2 files changed

+50
-14
lines changed

2 files changed

+50
-14
lines changed

numpy/core/src/multiarray/item_selection.c

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2424,10 +2424,7 @@ PyArray_Nonzero(PyArrayObject *self)
24242424
PyObject *ret_tuple;
24252425
npy_intp ret_dims[2];
24262426
PyArray_NonzeroFunc *nonzero = PyArray_DESCR(self)->f->nonzero;
2427-
char *data;
2428-
npy_intp stride, count;
24292427
npy_intp nonzero_count;
2430-
npy_intp *multi_index;
24312428

24322429
NpyIter *iter;
24332430
NpyIter_IterNextFunc *iternext;
@@ -2454,23 +2451,47 @@ PyArray_Nonzero(PyArrayObject *self)
24542451

24552452
/* If it's a one-dimensional result, don't use an iterator */
24562453
if (ndim <= 1) {
2457-
npy_intp j;
2454+
npy_intp * multi_index = (npy_intp *)PyArray_DATA(ret);
2455+
char * data = PyArray_BYTES(self);
2456+
npy_intp stride = (ndim == 0) ? 0 : PyArray_STRIDE(self, 0);
2457+
npy_intp count = (ndim == 0) ? 1 : PyArray_DIM(self, 0);
24582458

2459-
multi_index = (npy_intp *)PyArray_DATA(ret);
2460-
data = PyArray_BYTES(self);
2461-
stride = (ndim == 0) ? 0 : PyArray_STRIDE(self, 0);
2462-
count = (ndim == 0) ? 1 : PyArray_DIM(self, 0);
2459+
/* nothing to do */
2460+
if (nonzero_count == 0) {
2461+
goto finish;
2462+
}
24632463

2464+
/* avoid function call for bool */
24642465
if (PyArray_ISBOOL(self)) {
2465-
/* avoid function call for bool */
2466-
for (j = 0; j < count; ++j) {
2467-
if (*data != 0) {
2468-
*multi_index++ = j;
2466+
/*
2467+
* use fast memchr variant for sparse data, see gh-4370
2468+
* the fast bool count followed by this sparse path is faster
2469+
* than combining the two loops, even for larger arrays
2470+
*/
2471+
if (((double)nonzero_count / count) <= 0.1) {
2472+
npy_intp subsize;
2473+
npy_intp j = 0;
2474+
while (1) {
2475+
npy_memchr(data + j * stride, 0, stride, count, &subsize, 1);
2476+
j += subsize;
2477+
if (j >= count) {
2478+
break;
2479+
}
2480+
*multi_index++ = j++;
2481+
}
2482+
}
2483+
else {
2484+
npy_intp j;
2485+
for (j = 0; j < count; ++j) {
2486+
if (*data != 0) {
2487+
*multi_index++ = j;
2488+
}
2489+
data += stride;
24692490
}
2470-
data += stride;
24712491
}
24722492
}
24732493
else {
2494+
npy_intp j;
24742495
for (j = 0; j < count; ++j) {
24752496
if (nonzero(data, self)) {
24762497
*multi_index++ = j;
@@ -2498,6 +2519,7 @@ PyArray_Nonzero(PyArrayObject *self)
24982519
}
24992520

25002521
if (NpyIter_GetIterSize(iter) != 0) {
2522+
npy_intp * multi_index;
25012523
/* Get the pointers for inner loop iteration */
25022524
iternext = NpyIter_GetIterNext(iter, NULL);
25032525
if (iternext == NULL) {
@@ -2558,7 +2580,7 @@ PyArray_Nonzero(PyArrayObject *self)
25582580
else {
25592581
for (i = 0; i < ndim; ++i) {
25602582
PyArrayObject *view;
2561-
stride = ndim*NPY_SIZEOF_INTP;
2583+
npy_intp stride = ndim * NPY_SIZEOF_INTP;
25622584

25632585
view = (PyArrayObject *)PyArray_New(Py_TYPE(self), 1,
25642586
&nonzero_count,

numpy/core/tests/test_numeric.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,20 @@ def test_nonzero_twodim(self):
936936
assert_equal(np.nonzero(x['a'].T), ([0, 1, 1, 2], [1, 1, 2, 0]))
937937
assert_equal(np.nonzero(x['b'].T), ([0, 0, 1, 2, 2], [0, 1, 2, 0, 2]))
938938

939+
def test_sparse(self):
940+
# test special sparse condition boolean code path
941+
for i in range(20):
942+
c = np.zeros(200, dtype=np.bool)
943+
c[i::20] = True
944+
assert_equal(np.nonzero(c)[0], np.arange(i, 200 + i, 20))
945+
946+
c = np.zeros(400, dtype=np.bool)
947+
c[10 + i:20 + i] = True
948+
c[20 + i*2] = True
949+
assert_equal(np.nonzero(c)[0],
950+
np.concatenate((np.arange(10 +i, 20 + i), [20 +i*2])))
951+
952+
939953
class TestIndex(TestCase):
940954
def test_boolean(self):
941955
a = rand(3, 5, 8)

0 commit comments

Comments
 (0)