Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit c178bac

Browse files
authored
Merge pull request #23881 from asmeurer/array_api-2022
ENH: Add array API standard v2022.12 support to numpy.array_api
2 parents 26edc98 + 91153af commit c178bac

13 files changed

+332
-108
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Array API v2022.12 support in ``numpy.array_api``
2+
-------------------------------------------------
3+
4+
- ``numpy.array_api`` now full supports the `v2022.12 version
5+
<https://data-apis.org/array-api/2022.12>`__ of the array API standard. Note
6+
that this does not yet include the optional ``fft`` extension in the
7+
standard.

numpy/array_api/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@
121121
"The numpy.array_api submodule is still experimental. See NEP 47.", stacklevel=2
122122
)
123123

124-
__array_api_version__ = "2021.12"
124+
__array_api_version__ = "2022.12"
125125

126126
__all__ = ["__array_api_version__"]
127127

@@ -173,6 +173,7 @@
173173
broadcast_to,
174174
can_cast,
175175
finfo,
176+
isdtype,
176177
iinfo,
177178
result_type,
178179
)
@@ -198,6 +199,8 @@
198199
uint64,
199200
float32,
200201
float64,
202+
complex64,
203+
complex128,
201204
bool,
202205
)
203206

@@ -232,6 +235,7 @@
232235
bitwise_right_shift,
233236
bitwise_xor,
234237
ceil,
238+
conj,
235239
cos,
236240
cosh,
237241
divide,
@@ -242,6 +246,7 @@
242246
floor_divide,
243247
greater,
244248
greater_equal,
249+
imag,
245250
isfinite,
246251
isinf,
247252
isnan,
@@ -261,6 +266,7 @@
261266
not_equal,
262267
positive,
263268
pow,
269+
real,
264270
remainder,
265271
round,
266272
sign,

numpy/array_api/_array_object.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_integer_dtypes,
2525
_integer_or_boolean_dtypes,
2626
_floating_dtypes,
27+
_complex_floating_dtypes,
2728
_numeric_dtypes,
2829
_result_type,
2930
_dtype_categories,
@@ -139,7 +140,7 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor
139140

140141
if self.dtype not in _dtype_categories[dtype_category]:
141142
raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
142-
if isinstance(other, (int, float, bool)):
143+
if isinstance(other, (int, complex, float, bool)):
143144
other = self._promote_scalar(other)
144145
elif isinstance(other, Array):
145146
if other.dtype not in _dtype_categories[dtype_category]:
@@ -189,11 +190,23 @@ def _promote_scalar(self, scalar):
189190
raise TypeError(
190191
"Python int scalars cannot be promoted with bool arrays"
191192
)
193+
if self.dtype in _integer_dtypes:
194+
info = np.iinfo(self.dtype)
195+
if not (info.min <= scalar <= info.max):
196+
raise OverflowError(
197+
"Python int scalars must be within the bounds of the dtype for integer arrays"
198+
)
199+
# int + array(floating) is allowed
192200
elif isinstance(scalar, float):
193201
if self.dtype not in _floating_dtypes:
194202
raise TypeError(
195203
"Python float scalars can only be promoted with floating-point arrays."
196204
)
205+
elif isinstance(scalar, complex):
206+
if self.dtype not in _complex_floating_dtypes:
207+
raise TypeError(
208+
"Python complex scalars can only be promoted with complex floating-point arrays."
209+
)
197210
else:
198211
raise TypeError("'scalar' must be a Python scalar")
199212

@@ -454,11 +467,19 @@ def __bool__(self: Array, /) -> bool:
454467
# Note: This is an error here.
455468
if self._array.ndim != 0:
456469
raise TypeError("bool is only allowed on arrays with 0 dimensions")
457-
if self.dtype not in _boolean_dtypes:
458-
raise ValueError("bool is only allowed on boolean arrays")
459470
res = self._array.__bool__()
460471
return res
461472

473+
def __complex__(self: Array, /) -> complex:
474+
"""
475+
Performs the operation __complex__.
476+
"""
477+
# Note: This is an error here.
478+
if self._array.ndim != 0:
479+
raise TypeError("complex is only allowed on arrays with 0 dimensions")
480+
res = self._array.__complex__()
481+
return res
482+
462483
def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule:
463484
"""
464485
Performs the operation __dlpack__.
@@ -492,16 +513,16 @@ def __float__(self: Array, /) -> float:
492513
# Note: This is an error here.
493514
if self._array.ndim != 0:
494515
raise TypeError("float is only allowed on arrays with 0 dimensions")
495-
if self.dtype not in _floating_dtypes:
496-
raise ValueError("float is only allowed on floating-point arrays")
516+
if self.dtype in _complex_floating_dtypes:
517+
raise TypeError("float is not allowed on complex floating-point arrays")
497518
res = self._array.__float__()
498519
return res
499520

500521
def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
501522
"""
502523
Performs the operation __floordiv__.
503524
"""
504-
other = self._check_allowed_dtypes(other, "numeric", "__floordiv__")
525+
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
505526
if other is NotImplemented:
506527
return other
507528
self, other = self._normalize_two_args(self, other)
@@ -512,7 +533,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
512533
"""
513534
Performs the operation __ge__.
514535
"""
515-
other = self._check_allowed_dtypes(other, "numeric", "__ge__")
536+
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
516537
if other is NotImplemented:
517538
return other
518539
self, other = self._normalize_two_args(self, other)
@@ -542,7 +563,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
542563
"""
543564
Performs the operation __gt__.
544565
"""
545-
other = self._check_allowed_dtypes(other, "numeric", "__gt__")
566+
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
546567
if other is NotImplemented:
547568
return other
548569
self, other = self._normalize_two_args(self, other)
@@ -556,8 +577,8 @@ def __int__(self: Array, /) -> int:
556577
# Note: This is an error here.
557578
if self._array.ndim != 0:
558579
raise TypeError("int is only allowed on arrays with 0 dimensions")
559-
if self.dtype not in _integer_dtypes:
560-
raise ValueError("int is only allowed on integer arrays")
580+
if self.dtype in _complex_floating_dtypes:
581+
raise TypeError("int is not allowed on complex floating-point arrays")
561582
res = self._array.__int__()
562583
return res
563584

@@ -581,7 +602,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
581602
"""
582603
Performs the operation __le__.
583604
"""
584-
other = self._check_allowed_dtypes(other, "numeric", "__le__")
605+
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
585606
if other is NotImplemented:
586607
return other
587608
self, other = self._normalize_two_args(self, other)
@@ -603,7 +624,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
603624
"""
604625
Performs the operation __lt__.
605626
"""
606-
other = self._check_allowed_dtypes(other, "numeric", "__lt__")
627+
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
607628
if other is NotImplemented:
608629
return other
609630
self, other = self._normalize_two_args(self, other)
@@ -626,7 +647,7 @@ def __mod__(self: Array, other: Union[int, float, Array], /) -> Array:
626647
"""
627648
Performs the operation __mod__.
628649
"""
629-
other = self._check_allowed_dtypes(other, "numeric", "__mod__")
650+
other = self._check_allowed_dtypes(other, "real numeric", "__mod__")
630651
if other is NotImplemented:
631652
return other
632653
self, other = self._normalize_two_args(self, other)
@@ -808,7 +829,7 @@ def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
808829
"""
809830
Performs the operation __ifloordiv__.
810831
"""
811-
other = self._check_allowed_dtypes(other, "numeric", "__ifloordiv__")
832+
other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__")
812833
if other is NotImplemented:
813834
return other
814835
self._array.__ifloordiv__(other._array)
@@ -818,7 +839,7 @@ def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
818839
"""
819840
Performs the operation __rfloordiv__.
820841
"""
821-
other = self._check_allowed_dtypes(other, "numeric", "__rfloordiv__")
842+
other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__")
822843
if other is NotImplemented:
823844
return other
824845
self, other = self._normalize_two_args(self, other)
@@ -874,7 +895,7 @@ def __imod__(self: Array, other: Union[int, float, Array], /) -> Array:
874895
"""
875896
Performs the operation __imod__.
876897
"""
877-
other = self._check_allowed_dtypes(other, "numeric", "__imod__")
898+
other = self._check_allowed_dtypes(other, "real numeric", "__imod__")
878899
if other is NotImplemented:
879900
return other
880901
self._array.__imod__(other._array)
@@ -884,7 +905,7 @@ def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array:
884905
"""
885906
Performs the operation __rmod__.
886907
"""
887-
other = self._check_allowed_dtypes(other, "numeric", "__rmod__")
908+
other = self._check_allowed_dtypes(other, "real numeric", "__rmod__")
888909
if other is NotImplemented:
889910
return other
890911
self, other = self._normalize_two_args(self, other)

numpy/array_api/_data_type_functions.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
from __future__ import annotations
22

33
from ._array_object import Array
4-
from ._dtypes import _all_dtypes, _result_type
4+
from ._dtypes import (
5+
_all_dtypes,
6+
_boolean_dtypes,
7+
_signed_integer_dtypes,
8+
_unsigned_integer_dtypes,
9+
_integer_dtypes,
10+
_real_floating_dtypes,
11+
_complex_floating_dtypes,
12+
_numeric_dtypes,
13+
_result_type,
14+
)
515

616
from dataclasses import dataclass
717
from typing import TYPE_CHECKING, List, Tuple, Union
@@ -80,13 +90,15 @@ class finfo_object:
8090
max: float
8191
min: float
8292
smallest_normal: float
93+
dtype: Dtype
8394

8495

8596
@dataclass
8697
class iinfo_object:
8798
bits: int
8899
max: int
89100
min: int
101+
dtype: Dtype
90102

91103

92104
def finfo(type: Union[Dtype, Array], /) -> finfo_object:
@@ -104,6 +116,7 @@ def finfo(type: Union[Dtype, Array], /) -> finfo_object:
104116
float(fi.max),
105117
float(fi.min),
106118
float(fi.smallest_normal),
119+
fi.dtype,
107120
)
108121

109122

@@ -114,9 +127,47 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
114127
See its docstring for more information.
115128
"""
116129
ii = np.iinfo(type)
117-
return iinfo_object(ii.bits, ii.max, ii.min)
130+
return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype)
118131

119132

133+
# Note: isdtype is a new function from the 2022.12 array API specification.
134+
def isdtype(
135+
dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]]
136+
) -> bool:
137+
"""
138+
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
139+
140+
See
141+
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
142+
for more details
143+
"""
144+
if isinstance(kind, tuple):
145+
# Disallow nested tuples
146+
if any(isinstance(k, tuple) for k in kind):
147+
raise TypeError("'kind' must be a dtype, str, or tuple of dtypes and strs")
148+
return any(isdtype(dtype, k) for k in kind)
149+
elif isinstance(kind, str):
150+
if kind == 'bool':
151+
return dtype in _boolean_dtypes
152+
elif kind == 'signed integer':
153+
return dtype in _signed_integer_dtypes
154+
elif kind == 'unsigned integer':
155+
return dtype in _unsigned_integer_dtypes
156+
elif kind == 'integral':
157+
return dtype in _integer_dtypes
158+
elif kind == 'real floating':
159+
return dtype in _real_floating_dtypes
160+
elif kind == 'complex floating':
161+
return dtype in _complex_floating_dtypes
162+
elif kind == 'numeric':
163+
return dtype in _numeric_dtypes
164+
else:
165+
raise ValueError(f"Unrecognized data type kind: {kind!r}")
166+
elif kind in _all_dtypes:
167+
return dtype == kind
168+
else:
169+
raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}")
170+
120171
def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
121172
"""
122173
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.

0 commit comments

Comments
 (0)