Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 06ef89f

Browse files
MaanasAroraseberg
andauthored
ENH: Descending sorts with built-in dtype implementations (#31345)
This PR rewires all non-generic sorts to go through the newer API layer (but still fills the old slots). This then enables implementing `descending=True` for all of these, although the generic sorts will currently fail as they use the legacy fallback path that doesn't support descending. Another big change is that the setup is moved to C++ removing the need to explicitly instantiate all templates in C++ and as part of this the sort files are now split into `.cpp` and `.hpp` as we are not instantiating the function inside the new `npysort_methods.cpp` which gets them from the `.hpp` (this split is the biggest cause of the diff). Otherwise a `reverse` template parameter is threaded through everything to implement both sorts and we use `cmp<Tag, reverse>` or similar to switch This PR does not yet set up descending=True SIMD sorts for x86-simd-sorts because it currently uses a different NaN order convention (although this is easily fixable). Co-authored-by: Sebastian Berg <[email protected]>
1 parent 50c9f0f commit 06ef89f

38 files changed

Lines changed: 4626 additions & 4990 deletions
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
New `descending` keyword argument for ``numpy.sort`` and ``numpy.argsort``
2+
--------------------------------------------------------------------------
3+
Users can now pass the `descending=True` keyword argument to ``numpy.sort`` and ``numpy.argsort``
4+
to sort and argsort arrays in descending order. NaN values, if present, are sorted to the end
5+
of the array in both ascending and descending sorts. This feature is available for all built-in
6+
dtypes except `void`, `object`, and `generic`. Note that SIMD optimizations for sorting are
7+
currently not available for descending sorts, so performance may be slower.

numpy/__init__.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1733,6 +1733,7 @@ class _ArrayOrScalarCommon:
17331733
order: str | Sequence[str] | None = ...,
17341734
*,
17351735
stable: py_bool | None = ...,
1736+
descending: py_bool | None = ...,
17361737
) -> NDArray[intp]: ...
17371738

17381739
@overload # axis=None (default), out=None (default), keepdims=False (default)
@@ -3741,6 +3742,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
37413742
order: str | Sequence[str] | None = None,
37423743
*,
37433744
stable: py_bool | None = None,
3745+
descending: py_bool | None = None,
37443746
) -> None: ...
37453747

37463748
# Keep in sync with `MaskedArray.trace`
@@ -5068,7 +5070,7 @@ class generic(_ArrayOrScalarCommon, Generic[_ItemT_co]):
50685070
) -> Never: ...
50695071
def diagonal(self: Never, /, offset: L[0] = 0, axis1: L[0] = 0, axis2: L[1] = 1) -> Never: ... # type: ignore[misc]
50705072
def swapaxes(self: Never, axis1: Never, axis2: Never, /) -> Never: ... # type: ignore[misc]
5071-
def sort(self: Never, /, axis: L[-1] = -1, kind: None = None, order: None = None, *, stable: None = None) -> Never: ... # type: ignore[misc]
5073+
def sort(self: Never, /, axis: L[-1] = -1, kind: None = None, order: None = None, *, stable: None = None, descending: None = None) -> Never: ... # type: ignore[misc]
50725074
def nonzero(self: Never, /) -> Never: ... # type: ignore[misc]
50735075
def setfield(self: Never, val: Never, /, dtype: Never, offset: L[0] = 0) -> None: ... # type: ignore[misc]
50745076
def searchsorted(self: Never, v: Never, /, side: L["left"] = "left", sorter: None = None) -> Never: ... # type: ignore[misc]

numpy/_core/_add_newdocs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3488,9 +3488,9 @@ def _array_method_doc(name: str, params: str, doc: str) -> None:
34883488
numpy.argmin : equivalent function
34893489
""")
34903490

3491-
_array_method_doc('argsort', "axis=-1, kind=None, order=None, *, stable=None",
3491+
_array_method_doc('argsort', "axis=-1, kind=None, order=None, *, stable=None, descending=None",
34923492
"""
3493-
a.argsort(axis=-1, kind=None, order=None, *, stable=None)
3493+
a.argsort(axis=-1, kind=None, order=None, *, stable=None, descending=None)
34943494
34953495
Returns the indices that would sort this array.
34963496
@@ -4413,9 +4413,9 @@ def _array_method_doc(name: str, params: str, doc: str) -> None:
44134413
ValueError: cannot set WRITEBACKIFCOPY flag to True
44144414
""")
44154415

4416-
_array_method_doc('sort', "axis=-1, kind=None, order=None, *, stable=None",
4416+
_array_method_doc('sort', "axis=-1, kind=None, order=None, *, stable=None, descending=None",
44174417
"""
4418-
a.sort(axis=-1, kind=None, order=None, *, stable=None)
4418+
a.sort(axis=-1, kind=None, order=None, *, stable=None, descending=None)
44194419
44204420
Sort an array in-place. Refer to `numpy.sort` for full documentation.
44214421

numpy/_core/defchararray.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ def __mod__(self, i):
722722
def __rmod__(self, other):
723723
return NotImplemented
724724

725-
def argsort(self, axis=-1, kind=None, order=None, *, stable=None):
725+
def argsort(self, axis=-1, kind=None, order=None, *, stable=None, descending=None):
726726
"""
727727
Return the indices that sort the array lexicographically.
728728
@@ -740,7 +740,10 @@ def argsort(self, axis=-1, kind=None, order=None, *, stable=None):
740740
dtype='|S5')
741741
742742
"""
743-
return self.__array__().argsort(axis, kind, order, stable=stable)
743+
return self.__array__().argsort(
744+
axis, kind, order, stable=stable, descending=descending
745+
)
746+
744747
argsort.__doc__ = ndarray.argsort.__doc__
745748

746749
def capitalize(self):

numpy/_core/fromnumeric.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -932,12 +932,14 @@ def argpartition(a, kth, axis=-1, kind='introselect', order=None):
932932
return _wrapfunc(a, 'argpartition', kth, axis=axis, kind=kind, order=order)
933933

934934

935-
def _sort_dispatcher(a, axis=None, kind=None, order=None, *, stable=None):
935+
def _sort_dispatcher(
936+
a, axis=None, kind=None, order=None, *, stable=None, descending=None
937+
):
936938
return (a,)
937939

938940

939941
@array_function_dispatch(_sort_dispatcher)
940-
def sort(a, axis=-1, kind=None, order=None, *, stable=None):
942+
def sort(a, axis=-1, kind=None, order=None, *, stable=None, descending=None):
941943
"""
942944
Return a sorted copy of an array.
943945
@@ -966,6 +968,13 @@ def sort(a, axis=-1, kind=None, order=None, *, stable=None):
966968
this option selects ``kind='stable'``. Default: ``None``.
967969
968970
.. versionadded:: 2.0.0
971+
descending : bool, optional
972+
Sort order. If ``True``, the returned array will be sorted in
973+
descending order. If ``False`` or ``None``, the returned array will
974+
be sorted in ascending order. Values that are NaN are sorted to the
975+
end for both orders. Default: ``None``.
976+
977+
.. versionadded:: 2.5.0
969978
970979
Returns
971980
-------
@@ -1089,16 +1098,18 @@ def sort(a, axis=-1, kind=None, order=None, *, stable=None):
10891098
axis = -1
10901099
else:
10911100
a = asanyarray(a).copy(order="K")
1092-
a.sort(axis=axis, kind=kind, order=order, stable=stable)
1101+
a.sort(axis=axis, kind=kind, order=order, stable=stable, descending=descending)
10931102
return a
10941103

10951104

1096-
def _argsort_dispatcher(a, axis=None, kind=None, order=None, *, stable=None):
1105+
def _argsort_dispatcher(
1106+
a, axis=None, kind=None, order=None, *, stable=None, descending=None
1107+
):
10971108
return (a,)
10981109

10991110

11001111
@array_function_dispatch(_argsort_dispatcher)
1101-
def argsort(a, axis=-1, kind=None, order=None, *, stable=None):
1112+
def argsort(a, axis=-1, kind=None, order=None, *, stable=None, descending=None):
11021113
"""
11031114
Returns the indices that would sort an array.
11041115
@@ -1131,6 +1142,13 @@ def argsort(a, axis=-1, kind=None, order=None, *, stable=None):
11311142
this option selects ``kind='stable'``. Default: ``None``.
11321143
11331144
.. versionadded:: 2.0.0
1145+
descending : bool, optional
1146+
Sort order. If ``True``, the returned array will be sorted in
1147+
descending order. If ``False`` or ``None``, the returned array will
1148+
be sorted in ascending order. Values that are NaN are sorted to the
1149+
end for both orders. Default: ``None``.
1150+
1151+
.. versionadded:: 2.5.0
11341152
11351153
Returns
11361154
-------
@@ -1211,7 +1229,13 @@ def argsort(a, axis=-1, kind=None, order=None, *, stable=None):
12111229
12121230
"""
12131231
return _wrapfunc(
1214-
a, 'argsort', axis=axis, kind=kind, order=order, stable=stable
1232+
a,
1233+
"argsort",
1234+
axis=axis,
1235+
kind=kind,
1236+
order=order,
1237+
stable=stable,
1238+
descending=descending,
12151239
)
12161240

12171241
def _argmax_dispatcher(a, axis=None, out=None, *, keepdims=np._NoValue):

numpy/_core/fromnumeric.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,7 @@ def sort[ArrayT: np.ndarray](
453453
order: str | Sequence[str] | None = None,
454454
*,
455455
stable: bool | None = None,
456+
descending: bool | None = None,
456457
) -> ArrayT: ...
457458
@overload
458459
def sort[ScalarT: np.generic](
@@ -462,6 +463,7 @@ def sort[ScalarT: np.generic](
462463
order: str | Sequence[str] | None = None,
463464
*,
464465
stable: bool | None = None,
466+
descending: bool | None = None,
465467
) -> NDArray[ScalarT]: ...
466468
@overload
467469
def sort[ScalarT: np.generic](
@@ -471,6 +473,7 @@ def sort[ScalarT: np.generic](
471473
order: str | Sequence[str] | None = None,
472474
*,
473475
stable: bool | None = None,
476+
descending: bool | None = None,
474477
) -> _Array1D[ScalarT]: ...
475478
@overload
476479
def sort(
@@ -480,6 +483,7 @@ def sort(
480483
order: str | Sequence[str] | None = None,
481484
*,
482485
stable: bool | None = None,
486+
descending: bool | None = None,
483487
) -> NDArray[Any]: ...
484488
@overload
485489
def sort(
@@ -489,6 +493,7 @@ def sort(
489493
order: str | Sequence[str] | None = None,
490494
*,
491495
stable: bool | None = None,
496+
descending: bool | None = None,
492497
) -> _Array1D[Any]: ...
493498

494499
#
@@ -500,6 +505,7 @@ def argsort[ShapeT: _Shape](
500505
order: str | Sequence[str] | None = None,
501506
*,
502507
stable: bool | None = None,
508+
descending: bool | None = None,
503509
) -> np.ndarray[ShapeT, np.dtype[np.intp]]: ...
504510
@overload
505511
def argsort(
@@ -509,6 +515,7 @@ def argsort(
509515
order: str | Sequence[str] | None = None,
510516
*,
511517
stable: bool | None = None,
518+
descending: bool | None = None,
512519
) -> NDArray[np.intp]: ...
513520
@overload
514521
def argsort(
@@ -518,6 +525,7 @@ def argsort(
518525
order: str | Sequence[str] | None = None,
519526
*,
520527
stable: bool | None = None,
528+
descending: bool | None = None,
521529
) -> _Array1D[np.intp]: ...
522530

523531
# keep in sync with `argmin` below

numpy/_core/meson.build

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,11 +1232,11 @@ src_multiarray = multiarray_gen_headers + [
12321232
'src/multiarray/temp_elide.c',
12331233
'src/multiarray/usertypes.c',
12341234
'src/multiarray/vdot.c',
1235-
'src/npysort/quicksort.cpp',
1235+
'src/npysort/quicksort_generic.cpp',
12361236
'src/npysort/mergesort.cpp',
1237-
'src/npysort/timsort.cpp',
1237+
'src/npysort/timsort_generic.cpp',
12381238
'src/npysort/heapsort.cpp',
1239-
'src/npysort/radixsort.cpp',
1239+
'src/npysort/npysort_methods.cpp',
12401240
'src/common/npy_partition.h',
12411241
'src/npysort/selection.cpp',
12421242
'src/common/npy_binsearch.h',

numpy/_core/src/common/npy_sort.h.src

Lines changed: 2 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -24,72 +24,13 @@ extern "C" {
2424
#endif
2525

2626

27-
28-
/*
29-
*****************************************************************************
30-
** NUMERIC SORTS **
31-
*****************************************************************************
32-
*/
33-
34-
35-
/**begin repeat
36-
*
37-
* #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong,
38-
* longlong, ulonglong, half, float, double, longdouble,
39-
* cfloat, cdouble, clongdouble, datetime, timedelta#
40-
*/
41-
42-
NPY_NO_EXPORT int quicksort_@suff@(void *vec, npy_intp cnt, void *null);
43-
NPY_NO_EXPORT int heapsort_@suff@(void *vec, npy_intp cnt, void *null);
44-
NPY_NO_EXPORT int mergesort_@suff@(void *vec, npy_intp cnt, void *null);
45-
NPY_NO_EXPORT int timsort_@suff@(void *vec, npy_intp cnt, void *null);
46-
NPY_NO_EXPORT int aquicksort_@suff@(void *vec, npy_intp *ind, npy_intp cnt, void *null);
47-
NPY_NO_EXPORT int aheapsort_@suff@(void *vec, npy_intp *ind, npy_intp cnt, void *null);
48-
NPY_NO_EXPORT int amergesort_@suff@(void *vec, npy_intp *ind, npy_intp cnt, void *null);
49-
NPY_NO_EXPORT int atimsort_@suff@(void *vec, npy_intp *ind, npy_intp cnt, void *null);
50-
51-
/**end repeat**/
52-
53-
/**begin repeat
54-
*
55-
* #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong,
56-
* longlong, ulonglong#
57-
*/
58-
#ifdef __cplusplus
59-
extern "C" {
60-
#endif
61-
NPY_NO_EXPORT int radixsort_@suff@(void *vec, npy_intp cnt, void *null);
62-
NPY_NO_EXPORT int aradixsort_@suff@(void *vec, npy_intp *ind, npy_intp cnt, void *null);
63-
#ifdef __cplusplus
64-
}
65-
#endif
66-
67-
/**end repeat**/
68-
69-
70-
7127
/*
7228
*****************************************************************************
73-
** STRING SORTS **
29+
** NEW SORT METHOD REGISTRATIONS **
7430
*****************************************************************************
7531
*/
7632

77-
78-
/**begin repeat
79-
*
80-
* #suff = string, unicode#
81-
*/
82-
83-
NPY_NO_EXPORT int quicksort_@suff@(void *vec, npy_intp cnt, void *arr);
84-
NPY_NO_EXPORT int heapsort_@suff@(void *vec, npy_intp cnt, void *arr);
85-
NPY_NO_EXPORT int mergesort_@suff@(void *vec, npy_intp cnt, void *arr);
86-
NPY_NO_EXPORT int timsort_@suff@(void *vec, npy_intp cnt, void *arr);
87-
NPY_NO_EXPORT int aquicksort_@suff@(void *vec, npy_intp *ind, npy_intp cnt, void *arr);
88-
NPY_NO_EXPORT int aheapsort_@suff@(void *vec, npy_intp *ind, npy_intp cnt, void *arr);
89-
NPY_NO_EXPORT int amergesort_@suff@(void *vec, npy_intp *ind, npy_intp cnt, void *arr);
90-
NPY_NO_EXPORT int atimsort_@suff@(void *vec, npy_intp *ind, npy_intp cnt, void *arr);
91-
92-
/**end repeat**/
33+
NPY_NO_EXPORT int register_all_sorts(void);
9334

9435

9536
/*

0 commit comments

Comments
 (0)