Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a676f4e

Browse files
committed
ENH: add max fastpath to partition for nan detection
Allows low overhead check for NaN via isnan(partition(d, (x, -1))[-1])
1 parent dd85748 commit a676f4e

2 files changed

Lines changed: 25 additions & 0 deletions

File tree

numpy/core/src/npysort/selection.c.src

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ static NPY_INLINE void store_pivot(npy_intp pivot, npy_intp kth,
6969
* npy_uint, npy_long, npy_ulong, npy_longlong, npy_ulonglong,
7070
* npy_ushort, npy_float, npy_double, npy_longdouble, npy_cfloat,
7171
* npy_cdouble, npy_clongdouble#
72+
* #inexact = 0*11, 1*7#
7273
*/
7374

7475
static npy_intp
@@ -322,6 +323,20 @@ int
322323
store_pivot(kth, kth, pivots, npiv);
323324
return 0;
324325
}
326+
else if (@inexact@ && kth == num - 1) {
327+
/* useful to check if NaN present via partition(d, (x, -1)) */
328+
npy_intp k;
329+
npy_intp maxidx = low;
330+
@type@ maxval = v[IDX(low)];
331+
for (k = low + 1; k < num; k++) {
332+
if (!@TYPE@_LT(v[IDX(k)], maxval)) {
333+
maxidx = k;
334+
maxval = v[IDX(k)];
335+
}
336+
}
337+
SWAP(SORTEE(kth), SORTEE(maxidx));
338+
return 0;
339+
}
325340

326341
/* dumb integer msb, float npy_log2 too slow for small parititions */
327342
{

numpy/core/tests/test_multiarray.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,16 @@ def test_partition(self):
12991299
mid = x.size // 2 + 1
13001300
assert_equal(np.partition(x, mid)[mid], mid)
13011301

1302+
# max
1303+
d = np.ones(10); d[1] = 4;
1304+
assert_equal(np.partition(d, (2, -1))[-1], 4)
1305+
assert_equal(np.partition(d, (2, -1))[2], 1)
1306+
assert_equal(d[np.argpartition(d, (2, -1))][-1], 4)
1307+
assert_equal(d[np.argpartition(d, (2, -1))][2], 1)
1308+
d[1] = np.nan
1309+
assert_(np.isnan(d[np.argpartition(d, (2, -1))][-1]))
1310+
assert_(np.isnan(np.partition(d, (2, -1))[-1]))
1311+
13021312
# equal elements
13031313
d = np.arange((47)) % 7
13041314
tgt = np.sort(np.arange((47)) % 7)

0 commit comments

Comments
 (0)