Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7a2b14a

Browse files
committed
Merge pull request numpy#4837 from juliantaylor/select-bug
BUG: wrong selection for orders falling into equal ranges
2 parents e8d1374 + d6c7a16 commit 7a2b14a

2 files changed

Lines changed: 25 additions & 3 deletions

File tree

numpy/core/src/npysort/selection.c.src

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,10 @@ int
390390
/* move pivot into position */
391391
SWAP(SORTEE(low), SORTEE(hh));
392392

393-
store_pivot(hh, kth, pivots, npiv);
393+
/* kth pivot stored later */
394+
if (hh != kth) {
395+
store_pivot(hh, kth, pivots, npiv);
396+
}
394397

395398
if (hh >= kth)
396399
high = hh - 1;
@@ -400,10 +403,11 @@ int
400403

401404
/* two elements */
402405
if (high == low + 1) {
403-
if (@TYPE@_LT(v[IDX(high)], v[IDX(low)]))
406+
if (@TYPE@_LT(v[IDX(high)], v[IDX(low)])) {
404407
SWAP(SORTEE(high), SORTEE(low))
405-
store_pivot(low, kth, pivots, npiv);
408+
}
406409
}
410+
store_pivot(kth, kth, pivots, npiv);
407411

408412
return 0;
409413
}

numpy/core/tests/test_multiarray.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,6 +1356,12 @@ def test_partition(self):
13561356
d[i:].partition(0, kind=k)
13571357
assert_array_equal(d, tgt)
13581358

1359+
d = np.array([0, 1, 2, 3, 4, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
1360+
7, 7, 7, 7, 7, 9])
1361+
kth = [0, 3, 19, 20]
1362+
assert_equal(np.partition(d, kth, kind=k)[kth], (0, 3, 7, 7))
1363+
assert_equal(d[np.argpartition(d, kth, kind=k)][kth], (0, 3, 7, 7))
1364+
13591365
d = np.array([2, 1])
13601366
d.partition(0, kind=k)
13611367
assert_raises(ValueError, d.partition, 2)
@@ -1551,6 +1557,18 @@ def test_partition_unicode_kind(self):
15511557
assert_raises(ValueError, d.partition, 2, kind=k)
15521558
assert_raises(ValueError, d.argpartition, 2, kind=k)
15531559

1560+
def test_partition_fuzz(self):
1561+
# a few rounds of random data testing
1562+
for j in range(10, 30):
1563+
for i in range(1, j - 2):
1564+
d = np.arange(j)
1565+
np.random.shuffle(d)
1566+
d = d % np.random.randint(2, 30)
1567+
idx = np.random.randint(d.size)
1568+
kth = [0, idx, i, i + 1]
1569+
tgt = np.sort(d)[kth]
1570+
assert_array_equal(np.partition(d, kth)[kth], tgt,
1571+
err_msg="data: %r\n kth: %r" % (d, kth))
15541572

15551573
def test_flatten(self):
15561574
x0 = np.array([[1, 2, 3], [4, 5, 6]], np.int32)

0 commit comments

Comments
 (0)