@@ -2424,10 +2424,7 @@ PyArray_Nonzero(PyArrayObject *self)
2424
2424
PyObject * ret_tuple ;
2425
2425
npy_intp ret_dims [2 ];
2426
2426
PyArray_NonzeroFunc * nonzero = PyArray_DESCR (self )-> f -> nonzero ;
2427
- char * data ;
2428
- npy_intp stride , count ;
2429
2427
npy_intp nonzero_count ;
2430
- npy_intp * multi_index ;
2431
2428
2432
2429
NpyIter * iter ;
2433
2430
NpyIter_IterNextFunc * iternext ;
@@ -2454,23 +2451,47 @@ PyArray_Nonzero(PyArrayObject *self)
2454
2451
2455
2452
/* If it's a one-dimensional result, don't use an iterator */
2456
2453
if (ndim <= 1 ) {
2457
- npy_intp j ;
2454
+ npy_intp * multi_index = (npy_intp * )PyArray_DATA (ret );
2455
+ char * data = PyArray_BYTES (self );
2456
+ npy_intp stride = (ndim == 0 ) ? 0 : PyArray_STRIDE (self , 0 );
2457
+ npy_intp count = (ndim == 0 ) ? 1 : PyArray_DIM (self , 0 );
2458
2458
2459
- multi_index = ( npy_intp * ) PyArray_DATA ( ret );
2460
- data = PyArray_BYTES ( self );
2461
- stride = ( ndim == 0 ) ? 0 : PyArray_STRIDE ( self , 0 ) ;
2462
- count = ( ndim == 0 ) ? 1 : PyArray_DIM ( self , 0 );
2459
+ /* nothing to do */
2460
+ if ( nonzero_count == 0 ) {
2461
+ goto finish ;
2462
+ }
2463
2463
2464
+ /* avoid function call for bool */
2464
2465
if (PyArray_ISBOOL (self )) {
2465
- /* avoid function call for bool */
2466
- for (j = 0 ; j < count ; ++ j ) {
2467
- if (* data != 0 ) {
2468
- * multi_index ++ = j ;
2466
+ /*
2467
+ * use fast memchr variant for sparse data, see gh-4370
2468
+ * running the fast bool count followed by this sparse path is faster
2469
+ * than combining the two loops, even for larger arrays
2470
+ */
2471
+ if (((double )nonzero_count / count ) <= 0.1 ) {
2472
+ npy_intp subsize ;
2473
+ npy_intp j = 0 ;
2474
+ while (1 ) {
2475
+ npy_memchr (data + j * stride , 0 , stride , count , & subsize , 1 );
2476
+ j += subsize ;
2477
+ if (j >= count ) {
2478
+ break ;
2479
+ }
2480
+ * multi_index ++ = j ++ ;
2481
+ }
2482
+ }
2483
+ else {
2484
+ npy_intp j ;
2485
+ for (j = 0 ; j < count ; ++ j ) {
2486
+ if (* data != 0 ) {
2487
+ * multi_index ++ = j ;
2488
+ }
2489
+ data += stride ;
2469
2490
}
2470
- data += stride ;
2471
2491
}
2472
2492
}
2473
2493
else {
2494
+ npy_intp j ;
2474
2495
for (j = 0 ; j < count ; ++ j ) {
2475
2496
if (nonzero (data , self )) {
2476
2497
* multi_index ++ = j ;
@@ -2498,6 +2519,7 @@ PyArray_Nonzero(PyArrayObject *self)
2498
2519
}
2499
2520
2500
2521
if (NpyIter_GetIterSize (iter ) != 0 ) {
2522
+ npy_intp * multi_index ;
2501
2523
/* Get the pointers for inner loop iteration */
2502
2524
iternext = NpyIter_GetIterNext (iter , NULL );
2503
2525
if (iternext == NULL ) {
@@ -2558,7 +2580,7 @@ PyArray_Nonzero(PyArrayObject *self)
2558
2580
else {
2559
2581
for (i = 0 ; i < ndim ; ++ i ) {
2560
2582
PyArrayObject * view ;
2561
- stride = ndim * NPY_SIZEOF_INTP ;
2583
+ npy_intp stride = ndim * NPY_SIZEOF_INTP ;
2562
2584
2563
2585
view = (PyArrayObject * )PyArray_New (Py_TYPE (self ), 1 ,
2564
2586
& nonzero_count ,
0 commit comments