diff --git a/numpy/ma/core.pyi b/numpy/ma/core.pyi index 26d05880c830..9fe2a2df0efa 100644 --- a/numpy/ma/core.pyi +++ b/numpy/ma/core.pyi @@ -9,6 +9,7 @@ from typing_extensions import deprecated from numpy import ( _OrderKACF, + _PartitionKind, _SortKind, amax, amin, @@ -26,6 +27,7 @@ from numpy._typing import ( NDArray, _ArrayLike, _DTypeLikeBool, + _ArrayLikeInt, _ScalarLike_co, _Shape, _ShapeLike, @@ -665,8 +667,20 @@ class MaskedArray(ndarray[_ShapeType_co, _DType_co]): ) -> _ArrayT: ... # - def partition(self, *args, **kwargs): ... - def argpartition(self, *args, **kwargs): ... + def partition( + self, + kth: _ArrayLikeInt, + axis: SupportsIndex = -1, + kind: _PartitionKind = "introselect", + order: str | Sequence[str] | None = None + ) -> None: ... + def argpartition( + self, + kth: _ArrayLikeInt, + axis: SupportsIndex = -1, + kind: _PartitionKind = "introselect", + order: str | Sequence[str] | None = None + ) -> _MaskedArray[intp]: ... def take(self, indices, axis=..., out=..., mode=...): ... copy: Any diff --git a/numpy/typing/tests/data/fail/ma.pyi b/numpy/typing/tests/data/fail/ma.pyi index a2ccd9218002..9894f51ab36e 100644 --- a/numpy/typing/tests/data/fail/ma.pyi +++ b/numpy/typing/tests/data/fail/ma.pyi @@ -2,9 +2,12 @@ from typing import Any import numpy as np import numpy.ma +import numpy.typing as npt m: np.ma.MaskedArray[tuple[int], np.dtype[np.float64]] +AR_b: npt.NDArray[np.bool] + m.shape = (3, 1) # E: Incompatible types in assignment m.dtype = np.bool # E: Incompatible types in assignment @@ -68,3 +71,15 @@ m.sort(endwith='cabbage') # E: No overload variant m.sort(fill_value=lambda: 'cabbage') # E: No overload variant m.sort(stable='cabbage') # E: No overload variant m.sort(stable=True) # E: No overload variant + +m.partition(['cabbage']) # E: No overload variant +m.partition(axis=(0,1)) # E: No overload variant +m.partition(kind='cabbage') # E: No overload variant +m.partition(order=lambda: 'cabbage') # E: No overload variant +m.partition(AR_b) # E: No overload variant + +m.argpartition(['cabbage']) # E: No overload variant +m.argpartition(axis=(0,1)) # E: No overload variant +m.argpartition(kind='cabbage') # E: No overload variant +m.argpartition(order=lambda: 'cabbage') # E: No overload variant +m.argpartition(AR_b) # E: No overload variant diff --git a/numpy/typing/tests/data/reveal/ma.pyi b/numpy/typing/tests/data/reveal/ma.pyi index baec77268676..52bff772a673 100644 --- a/numpy/typing/tests/data/reveal/ma.pyi +++ b/numpy/typing/tests/data/reveal/ma.pyi @@ -128,3 +128,9 @@ assert_type(np.ma.sort(MAR_f4), MaskedNDArray[np.float32]) assert_type(np.ma.sort(MAR_subclass), MaskedNDArraySubclass) assert_type(np.ma.sort([[0, 1], [2, 3]]), NDArray[Any]) assert_type(np.ma.sort(AR_f4), NDArray[np.float32]) + +assert_type(MAR_f4.partition(1), None) +assert_type(MAR_f4.partition(1, axis=0, kind='introselect', order='K'), None) + +assert_type(MAR_f4.argpartition(1), MaskedNDArray[np.intp]) +assert_type(MAR_1d.argpartition(1, axis=0, kind='introselect', order='K'), MaskedNDArray[np.intp])