diff --git a/src/numpy-stubs/lib/_type_check_impl.pyi b/src/numpy-stubs/lib/_type_check_impl.pyi index 00f03f00..7a372332 100644 --- a/src/numpy-stubs/lib/_type_check_impl.pyi +++ b/src/numpy-stubs/lib/_type_check_impl.pyi @@ -1,9 +1,23 @@ from collections.abc import Container, Iterable -from typing import Any, Literal as L, overload -from typing_extensions import TypeVar +from typing import Any, Literal as L, overload, type_check_only +from typing_extensions import Protocol, TypeVar import numpy as np -from numpy._typing import ArrayLike, NBitBase, NDArray, _64Bit, _ArrayLike, _ScalarLike_co, _SupportsDType +from _numtype import ( + Array, + ToBool_nd, + ToBytes_nd, + ToCLongDouble_nd, + ToComplex64_nd, + ToComplex128_nd, + ToFloat64_nd, + ToGeneric_0d, + ToGeneric_1nd, + ToIntP_nd, + ToStr_nd, + _ToArray1_1nd, +) +from numpy._typing import ArrayLike, _ArrayLike __all__ = [ "common_type", @@ -19,87 +33,136 @@ __all__ = [ "typename", ] +### + _T = TypeVar("_T") -_SCT = TypeVar("_SCT", bound=np.generic) -_NBit = TypeVar("_NBit", bound=NBitBase) -_NBit2 = TypeVar("_NBit2", bound=NBitBase) +_T_co = TypeVar("_T_co", covariant=True) +_ScalarT = TypeVar("_ScalarT", bound=np.generic) +_ScalarT_co = TypeVar("_ScalarT_co", bound=np.generic, covariant=True) + +@type_check_only +class _HasReal(Protocol[_T_co]): + @property + def real(self, /) -> _T_co: ... + +@type_check_only +class _HasImag(Protocol[_T_co]): + @property + def imag(self, /) -> _T_co: ... -def mintypecode(typechars: Iterable[str | ArrayLike], typeset: Container[str] = ..., default: str = ...) -> str: ... +@type_check_only +class _HasDType(Protocol[_ScalarT_co]): + @property + def dtype(self, /) -> np.dtype[_ScalarT_co]: ... + +### + +# +def mintypecode(typechars: Iterable[str | ArrayLike], typeset: str | Container[str] = "GDFgdf", default: str = "d") -> str: ... # @overload -def real(val: np._HasRealAndImag[_T, Any]) -> _T: ... +def real(val: _HasReal[_T]) -> _T: ... # type: ignore[overload-overlap] +@overload +def real(val: ToBool_nd) -> Array[np.bool]: ... +@overload +def real(val: ToIntP_nd) -> Array[np.intp]: ... +@overload +def real(val: ToFloat64_nd) -> Array[np.float64]: ... +@overload +def real(val: ToComplex128_nd) -> Array[np.complex128]: ... +@overload +def real(val: ToBytes_nd) -> Array[np.bytes_]: ... +@overload +def real(val: ToStr_nd) -> Array[np.str_]: ... @overload -def real(val: ArrayLike) -> NDArray[Any]: ... +def real(val: _ArrayLike[_ScalarT]) -> Array[_ScalarT]: ... +@overload +def real(val: ArrayLike) -> Array[Any]: ... # @overload -def imag(val: np._HasRealAndImag[Any, _T]) -> _T: ... +def imag(val: _HasImag[_T]) -> _T: ... # type: ignore[overload-overlap] +@overload +def imag(val: ToBool_nd) -> Array[np.bool]: ... +@overload +def imag(val: ToIntP_nd) -> Array[np.intp]: ... +@overload +def imag(val: ToFloat64_nd) -> Array[np.float64]: ... +@overload +def imag(val: ToComplex128_nd) -> Array[np.complex128]: ... +@overload +def imag(val: ToBytes_nd) -> Array[np.bytes_]: ... @overload -def imag(val: ArrayLike) -> NDArray[Any]: ... +def imag(val: ToStr_nd) -> Array[np.str_]: ... +@overload +def imag(val: _ArrayLike[_ScalarT]) -> Array[_ScalarT]: ... +@overload +def imag(val: ArrayLike) -> Array[Any]: ... # @overload -def iscomplex(x: _ScalarLike_co) -> np.bool: ... +def iscomplex(x: ToGeneric_0d) -> np.bool: ... # type: ignore[overload-overlap] @overload -def iscomplex(x: ArrayLike) -> NDArray[np.bool]: ... +def iscomplex(x: ToGeneric_1nd) -> Array[np.bool]: ... # @overload -def isreal(x: _ScalarLike_co) -> np.bool: ... +def isreal(x: ToGeneric_0d) -> np.bool: ... # type: ignore[overload-overlap] @overload -def isreal(x: ArrayLike) -> NDArray[np.bool]: ... +def isreal(x: ToGeneric_1nd) -> Array[np.bool]: ... # -def iscomplexobj(x: _SupportsDType[np.dtype[Any]] | ArrayLike) -> bool: ... -def isrealobj(x: _SupportsDType[np.dtype[Any]] | ArrayLike) -> bool: ... +def iscomplexobj(x: _HasDType[Any] | ArrayLike) -> bool: ... +def isrealobj(x: _HasDType[Any] | ArrayLike) -> bool: ... # @overload def nan_to_num( - x: _SCT, - copy: bool = ..., - nan: float = ..., - posinf: float | None = ..., - neginf: float | None = ..., -) -> _SCT: ... + x: _ScalarT, + copy: bool = True, + nan: float = 0.0, + posinf: float | None = None, + neginf: float | None = None, +) -> _ScalarT: ... @overload def nan_to_num( - x: _ScalarLike_co, - copy: bool = ..., - nan: float = ..., - posinf: float | None = ..., - neginf: float | None = ..., -) -> Any: ... + x: _ToArray1_1nd[_ScalarT], + copy: bool = True, + nan: float = 0.0, + posinf: float | None = None, + neginf: float | None = None, +) -> Array[_ScalarT]: ... @overload def nan_to_num( - x: _ArrayLike[_SCT], - copy: bool = ..., - nan: float = ..., - posinf: float | None = ..., - neginf: float | None = ..., -) -> NDArray[_SCT]: ... + x: ToGeneric_0d, + copy: bool = True, + nan: float = 0.0, + posinf: float | None = None, + neginf: float | None = None, +) -> Any: ... @overload def nan_to_num( - x: ArrayLike, - copy: bool = ..., - nan: float = ..., - posinf: float | None = ..., - neginf: float | None = ..., -) -> NDArray[Any]: ... + x: ToGeneric_1nd, + copy: bool = True, + nan: float = 0.0, + posinf: float | None = None, + neginf: float | None = None, +) -> Array[Any]: ... # If one passes a complex array to `real_if_close`, then one is reasonably # expected to verify the output dtype (so we can return an unsafe union here) @overload -def real_if_close( - a: _ArrayLike[np.complexfloating[_NBit]], - tol: float = 100, -) -> NDArray[np.floating[_NBit]] | NDArray[np.complexfloating[_NBit]]: ... +def real_if_close(a: ToComplex128_nd, tol: float = 100) -> Array[np.float64 | np.complex128]: ... +@overload +def real_if_close(a: ToComplex64_nd, tol: float = 100) -> Array[np.float32 | np.complex64]: ... +@overload +def real_if_close(a: ToCLongDouble_nd, tol: float = 100) -> Array[np.longdouble | np.clongdouble]: ... @overload -def real_if_close(a: _ArrayLike[_SCT], tol: float = 100) -> NDArray[_SCT]: ... +def real_if_close(a: _ArrayLike[_ScalarT], tol: float = 100) -> Array[_ScalarT]: ... @overload -def real_if_close(a: ArrayLike, tol: float = 100) -> NDArray[Any]: ... +def real_if_close(a: ArrayLike, tol: float = 100) -> Array[Any]: ... # @overload @@ -147,24 +210,159 @@ def typename(char: L["V"]) -> L["void"]: ... @overload def typename(char: L["O"]) -> L["object"]: ... -# +# NOTE: both mypy and pyright are report false-positive overlapping overloads (I think) +@overload +def common_type() -> type[np.float16]: ... +@overload +def common_type( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + array0: _HasDType[np.float16], + /, + *arrays: _HasDType[np.float16], +) -> type[np.float16]: ... +@overload +def common_type( # type: ignore[overload-overlap] + array0: _HasDType[np.float32], + /, + *arrays: _HasDType[np.float32 | np.float16], +) -> type[np.float32]: ... +@overload +def common_type( # pyright: ignore[reportOverlappingOverload] + array0: _HasDType[np.double | np.integer], + /, + *arrays: _HasDType[np.double | np.float32 | np.float16 | np.integer], +) -> type[np.float64]: ... +@overload +def common_type( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + array0: _HasDType[np.longdouble], + /, + *arrays: _HasDType[np.floating | np.integer], +) -> type[np.longdouble]: ... +@overload +def common_type( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + array0: _HasDType[np.complex64], + /, + *arrays: _HasDType[np.complex64 | np.float32 | np.float16], +) -> type[np.complex64]: ... +@overload +def common_type( # type: ignore[overload-overlap] + array0: _HasDType[np.cdouble], + /, + *arrays: _HasDType[np.cdouble | np.complex64 | np.double | np.float32 | np.float16 | np.integer], +) -> type[np.complex128]: ... +@overload +def common_type( + array0: _HasDType[np.clongdouble], + /, + *arrays: _HasDType[np.number], +) -> type[np.clongdouble]: ... +@overload +def common_type( # type: ignore[overload-overlap] + array0: _HasDType[np.float32 | np.float16], + array1: _HasDType[np.float32], + /, + *arrays: _HasDType[np.float32 | np.float16], +) -> type[np.float32]: ... @overload def common_type( - *arrays: _SupportsDType[np.dtype[np.integer]], + array0: _HasDType[np.double | np.float32 | np.float16 | np.integer], + array1: _HasDType[np.double | np.integer], + /, + *arrays: _HasDType[np.double | np.float32 | np.float16 | np.integer], ) -> type[np.float64]: ... @overload def common_type( - *arrays: _SupportsDType[np.dtype[np.floating[_NBit]]], -) -> type[np.floating[_NBit]]: ... + array0: _HasDType[np.floating | np.integer], + array1: _HasDType[np.longdouble], + /, + *arrays: _HasDType[np.floating | np.integer], +) -> type[np.longdouble]: ... +@overload +def common_type( # type: ignore[overload-overlap] + array0: _HasDType[np.complex64 | np.float32 | np.float16], + array1: _HasDType[np.complex64], + /, + *arrays: _HasDType[np.complex64 | np.float32 | np.float16], +) -> type[np.complex64]: ... +@overload +def common_type( + array0: _HasDType[np.double], + array1: _HasDType[np.cdouble | np.complex128 | np.complex64], + /, + *arrays: _HasDType[np.cdouble | np.complex64 | np.double | np.float32 | np.float16 | np.integer], +) -> type[np.complex128]: ... +@overload +def common_type( + array0: _HasDType[np.cdouble | np.complex128 | np.complex64], + array1: _HasDType[np.double], + /, + *arrays: _HasDType[np.cdouble | np.complex64 | np.double | np.float32 | np.float16 | np.integer], +) -> type[np.complex128]: ... +@overload +def common_type( + array0: _HasDType[np.cdouble | np.complex128 | np.complex64 | np.double | np.float32 | np.float16 | np.integer], + array1: _HasDType[np.cdouble], + /, + *arrays: _HasDType[np.cdouble | np.complex64 | np.double | np.float32 | np.float16 | np.integer], +) -> type[np.complex128]: ... +@overload +def common_type( + array0: _HasDType[np.cdouble | np.complex128 | np.complex64], + array1: _HasDType[np.cdouble | np.integer], + /, + *arrays: _HasDType[np.cdouble | np.complex64 | np.double | np.float32 | np.float16 | np.integer], +) -> type[np.complex128]: ... +@overload +def common_type( + array0: _HasDType[np.cdouble | np.integer], + array1: _HasDType[np.cdouble | np.complex128 | np.complex64], + /, + *arrays: _HasDType[np.cdouble | np.complex64 | np.double | np.float32 | np.float16 | np.integer], +) -> type[np.complex128]: ... +@overload +def common_type( # pyright: ignore[reportOverlappingOverload] + array0: _HasDType[np.floating | np.integer], + /, + *arrays: _HasDType[np.floating | np.integer], +) -> type[np.floating]: ... +@overload +def common_type( + array0: _HasDType[np.number], + array1: _HasDType[np.clongdouble], + /, + *arrays: _HasDType[np.number], +) -> type[np.clongdouble]: ... +@overload +def common_type( + array0: _HasDType[np.longdouble], + array1: _HasDType[np.complexfloating], + /, + *arrays: _HasDType[np.number], +) -> type[np.clongdouble]: ... +@overload +def common_type( + array0: _HasDType[np.complexfloating], + array1: _HasDType[np.longdouble], + /, + *arrays: _HasDType[np.number], +) -> type[np.clongdouble]: ... @overload def common_type( - *arrays: _SupportsDType[np.dtype[np.integer | np.floating[_NBit]]], -) -> type[np.floating[_NBit | _64Bit]]: ... + array0: _HasDType[np.complexfloating], + array1: _HasDType[np.number], + /, + *arrays: _HasDType[np.number], +) -> type[np.complexfloating]: ... @overload def common_type( - *arrays: _SupportsDType[np.dtype[np.floating[_NBit] | np.complexfloating[_NBit2]]], -) -> type[np.complexfloating[_NBit | _NBit2]]: ... + array0: _HasDType[np.number], + array1: _HasDType[np.complexfloating], + /, + *arrays: _HasDType[np.number], +) -> type[np.complexfloating]: ... @overload def common_type( - *arrays: _SupportsDType[np.dtype[np.integer | np.floating[_NBit] | np.complexfloating[_NBit2]]], -) -> type[np.complexfloating[_NBit | _NBit2 | _64Bit]]: ... + array0: _HasDType[np.number], + array1: _HasDType[np.number], + /, + *arrays: _HasDType[np.number], +) -> type[np.inexact]: ... diff --git a/test/static/accept/type_check.pyi b/test/static/accept/type_check.pyi index 1948de63..51edab28 100644 --- a/test/static/accept/type_check.pyi +++ b/test/static/accept/type_check.pyi @@ -3,7 +3,6 @@ from typing_extensions import assert_type import numpy as np import numpy.typing as npt -from numpy._typing import _16Bit, _32Bit, _64Bit, _128Bit f8: np.float64 f: float @@ -13,11 +12,11 @@ AR_i8: npt.NDArray[np.int64] AR_i4: npt.NDArray[np.int32] AR_f2: npt.NDArray[np.float16] AR_f8: npt.NDArray[np.float64] -AR_f16: npt.NDArray[np.floating[_128Bit]] +AR_f16: npt.NDArray[np.longdouble] AR_c8: npt.NDArray[np.complex64] AR_c16: npt.NDArray[np.complex128] -AR_LIKE_f: list[float] +AR_LIKE_i: list[int] class ComplexObj: real: slice @@ -28,20 +27,20 @@ assert_type(np.mintypecode(["f8"], typeset="qfQF"), str) assert_type(np.real(ComplexObj()), slice) assert_type(np.real(AR_f8), npt.NDArray[np.float64]) assert_type(np.real(AR_c16), npt.NDArray[np.float64]) -assert_type(np.real(AR_LIKE_f), npt.NDArray[Any]) +assert_type(np.real(AR_LIKE_i), npt.NDArray[np.intp]) assert_type(np.imag(ComplexObj()), slice) assert_type(np.imag(AR_f8), npt.NDArray[np.float64]) assert_type(np.imag(AR_c16), npt.NDArray[np.float64]) -assert_type(np.imag(AR_LIKE_f), npt.NDArray[Any]) +assert_type(np.imag(AR_LIKE_i), npt.NDArray[np.intp]) assert_type(np.iscomplex(f8), np.bool) assert_type(np.iscomplex(AR_f8), npt.NDArray[np.bool]) -assert_type(np.iscomplex(AR_LIKE_f), npt.NDArray[np.bool]) +assert_type(np.iscomplex(AR_LIKE_i), npt.NDArray[np.bool]) assert_type(np.isreal(f8), np.bool) assert_type(np.isreal(AR_f8), npt.NDArray[np.bool]) -assert_type(np.isreal(AR_LIKE_f), npt.NDArray[np.bool]) +assert_type(np.isreal(AR_LIKE_i), npt.NDArray[np.bool]) assert_type(np.iscomplexobj(f8), bool) assert_type(np.isrealobj(f8), bool) @@ -49,15 +48,12 @@ assert_type(np.isrealobj(f8), bool) assert_type(np.nan_to_num(f8), np.float64) assert_type(np.nan_to_num(f, copy=True), Any) assert_type(np.nan_to_num(AR_f8, nan=1.5), npt.NDArray[np.float64]) -assert_type(np.nan_to_num(AR_LIKE_f, posinf=9999), npt.NDArray[Any]) +assert_type(np.nan_to_num(AR_LIKE_i, posinf=9999), npt.NDArray[Any]) assert_type(np.real_if_close(AR_f8), npt.NDArray[np.float64]) -assert_type( - np.real_if_close(AR_c16), - npt.NDArray[np.floating[_64Bit]] | npt.NDArray[np.complexfloating[_64Bit]], -) -assert_type(np.real_if_close(AR_c8), npt.NDArray[np.float32] | npt.NDArray[np.complex64]) -assert_type(np.real_if_close(AR_LIKE_f), npt.NDArray[Any]) +assert_type(np.real_if_close(AR_c8), npt.NDArray[np.float32 | np.complex64]) +assert_type(np.real_if_close(AR_c16), npt.NDArray[np.float64 | np.complex128]) +assert_type(np.real_if_close(AR_LIKE_i), npt.NDArray[Any]) assert_type(np.typename("h"), Literal["short"]) assert_type(np.typename("B"), Literal["unsigned char"]) @@ -66,13 +62,7 @@ assert_type(np.typename("S1"), Literal["character"]) assert_type(np.common_type(AR_i4), type[np.float64]) assert_type(np.common_type(AR_f2), type[np.float16]) -assert_type(np.common_type(AR_f2, AR_i4), type[np.floating[_16Bit | _64Bit]]) -assert_type(np.common_type(AR_f16, AR_i4), type[np.floating[_64Bit | _128Bit]]) -assert_type( - np.common_type(AR_c8, AR_f2), - type[np.complexfloating[_16Bit | _32Bit]], -) -assert_type( - np.common_type(AR_f2, AR_c8, AR_i4), - type[np.complexfloating[_16Bit | _32Bit | _64Bit]], -) +assert_type(np.common_type(AR_f2, AR_i4), type[np.float64]) +assert_type(np.common_type(AR_f16, AR_i4), type[np.longdouble]) +assert_type(np.common_type(AR_c8, AR_f2), type[np.complex64]) +assert_type(np.common_type(AR_f2, AR_c8, AR_i4), type[np.complexfloating])