Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 723f2fb

Browse files
pllimseberg
andauthored
BUG: Do not forward descending unless explicitly passed (#31458)
Instead use np._NoValue and only forward to the method when descending was explicitly passed by the user. Co-authored-by: Sebastian Berg <[email protected]>
1 parent 19bb277 commit 723f2fb

3 files changed

Lines changed: 46 additions & 12 deletions

File tree

numpy/_core/fromnumeric.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -939,7 +939,7 @@ def _sort_dispatcher(
939939

940940

941941
@array_function_dispatch(_sort_dispatcher)
942-
def sort(a, axis=-1, kind=None, order=None, *, stable=None, descending=None):
942+
def sort(a, axis=-1, kind=None, order=None, *, stable=None, descending=np._NoValue):
943943
"""
944944
Return a sorted copy of an array.
945945
@@ -1098,7 +1098,11 @@ def sort(a, axis=-1, kind=None, order=None, *, stable=None, descending=None):
10981098
axis = -1
10991099
else:
11001100
a = asanyarray(a).copy(order="K")
1101-
a.sort(axis=axis, kind=kind, order=order, stable=stable, descending=descending)
1101+
# Sanitize for backward-compatibility
1102+
if descending is not np._NoValue:
1103+
a.sort(axis=axis, kind=kind, order=order, stable=stable, descending=descending)
1104+
else:
1105+
a.sort(axis=axis, kind=kind, order=order, stable=stable)
11021106
return a
11031107

11041108

@@ -1109,7 +1113,7 @@ def _argsort_dispatcher(
11091113

11101114

11111115
@array_function_dispatch(_argsort_dispatcher)
1112-
def argsort(a, axis=-1, kind=None, order=None, *, stable=None, descending=None):
1116+
def argsort(a, axis=-1, kind=None, order=None, *, stable=None, descending=np._NoValue):
11131117
"""
11141118
Returns the indices that would sort an array.
11151119
@@ -1228,14 +1232,24 @@ def argsort(a, axis=-1, kind=None, order=None, *, stable=None, descending=None):
12281232
array([0, 1])
12291233
12301234
"""
1235+
# Sanitize for backward-compatibility
1236+
if descending is not np._NoValue:
1237+
return _wrapfunc(
1238+
a,
1239+
"argsort",
1240+
axis=axis,
1241+
kind=kind,
1242+
order=order,
1243+
stable=stable,
1244+
descending=descending,
1245+
)
12311246
return _wrapfunc(
12321247
a,
12331248
"argsort",
12341249
axis=axis,
12351250
kind=kind,
12361251
order=order,
12371252
stable=stable,
1238-
descending=descending,
12391253
)
12401254

12411255
def _argmax_dispatcher(a, axis=None, out=None, *, keepdims=np._NoValue):

numpy/_core/fromnumeric.pyi

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def sort[ArrayT: np.ndarray](
453453
order: str | Sequence[str] | None = None,
454454
*,
455455
stable: bool | None = None,
456-
descending: bool | None = None,
456+
descending: bool | _NoValueType = ...,
457457
) -> ArrayT: ...
458458
@overload
459459
def sort[ScalarT: np.generic](
@@ -463,7 +463,7 @@ def sort[ScalarT: np.generic](
463463
order: str | Sequence[str] | None = None,
464464
*,
465465
stable: bool | None = None,
466-
descending: bool | None = None,
466+
descending: bool | _NoValueType = ...,
467467
) -> NDArray[ScalarT]: ...
468468
@overload
469469
def sort[ScalarT: np.generic](
@@ -473,7 +473,7 @@ def sort[ScalarT: np.generic](
473473
order: str | Sequence[str] | None = None,
474474
*,
475475
stable: bool | None = None,
476-
descending: bool | None = None,
476+
descending: bool | _NoValueType = ...,
477477
) -> _Array1D[ScalarT]: ...
478478
@overload
479479
def sort(
@@ -483,7 +483,7 @@ def sort(
483483
order: str | Sequence[str] | None = None,
484484
*,
485485
stable: bool | None = None,
486-
descending: bool | None = None,
486+
descending: bool | _NoValueType = ...,
487487
) -> NDArray[Any]: ...
488488
@overload
489489
def sort(
@@ -493,7 +493,7 @@ def sort(
493493
order: str | Sequence[str] | None = None,
494494
*,
495495
stable: bool | None = None,
496-
descending: bool | None = None,
496+
descending: bool | _NoValueType = ...,
497497
) -> _Array1D[Any]: ...
498498

499499
#
@@ -505,7 +505,7 @@ def argsort[ShapeT: _Shape](
505505
order: str | Sequence[str] | None = None,
506506
*,
507507
stable: bool | None = None,
508-
descending: bool | None = None,
508+
descending: bool | _NoValueType = ...,
509509
) -> np.ndarray[ShapeT, np.dtype[np.intp]]: ...
510510
@overload
511511
def argsort(
@@ -515,7 +515,7 @@ def argsort(
515515
order: str | Sequence[str] | None = None,
516516
*,
517517
stable: bool | None = None,
518-
descending: bool | None = None,
518+
descending: bool | _NoValueType = ...,
519519
) -> NDArray[np.intp]: ...
520520
@overload
521521
def argsort(
@@ -525,7 +525,7 @@ def argsort(
525525
order: str | Sequence[str] | None = None,
526526
*,
527527
stable: bool | None = None,
528-
descending: bool | None = None,
528+
descending: bool | _NoValueType = ...,
529529
) -> _Array1D[np.intp]: ...
530530

531531
# keep in sync with `argmin` below

numpy/_core/tests/test_multiarray.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2752,6 +2752,26 @@ def test_sort_unicode_kind(self):
27522752
assert_raises(ValueError, d.sort, kind=k)
27532753
assert_raises(ValueError, d.argsort, kind=k)
27542754

2755+
@pytest.mark.parametrize("func", [np.sort, np.argsort])
2756+
def test_descending_kwarg_forwarding(self, func):
2757+
# NumPy currently calls the sort/argsort methods. The descending
2758+
# kwarg introduced in 2.5 is only forwarded if True.
2759+
2760+
class MyArray(np.ndarray):
2761+
def sort(self, **kwargs):
2762+
MyArray.kwargs = kwargs # set on class (may use views)
2763+
2764+
def argsort(self, **kwargs):
2765+
MyArray.kwargs = kwargs # set on class (may use views)
2766+
2767+
m = np.array([1, 2, 3]).view(MyArray)
2768+
func(m)
2769+
assert "descending" not in MyArray.kwargs
2770+
func(m, descending=False) # OK if it was passed
2771+
assert not MyArray.kwargs["descending"]
2772+
func(m, descending=True) # must be passed
2773+
assert MyArray.kwargs["descending"]
2774+
27552775
def _test_sort_descending_nonan(self, a, stable, descending):
27562776
if not descending:
27572777
a = a[::-1]

0 commit comments

Comments
 (0)