Thanks to visit codestin.com
Credit goes to github.com

Skip to content

ENH: Add array API standard v2022.12 support to numpy.array_api #23881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add isdtype() to numpy.array_api
This is a new function in the v2022.12 version of the array API standard which
is used for determining if a given dtype is part of a set of given dtype
categories. This will also eventually be added to the main NumPy namespace,
but for now only exists in numpy.array_api as a purely strict version.
  • Loading branch information
asmeurer committed Jun 6, 2023
commit 173fbc7009719ce802aa70634fb93031a0c00cfb
1 change: 1 addition & 0 deletions numpy/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@
broadcast_to,
can_cast,
finfo,
isdtype,
iinfo,
result_type,
)
Expand Down
50 changes: 49 additions & 1 deletion numpy/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from __future__ import annotations

from ._array_object import Array
from ._dtypes import _all_dtypes, _result_type
from ._dtypes import (
_all_dtypes,
_boolean_dtypes,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
_integer_dtypes,
_real_floating_dtypes,
_complex_floating_dtypes,
_numeric_dtypes,
_result_type,
)

from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Tuple, Union
Expand Down Expand Up @@ -117,6 +127,44 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
return iinfo_object(ii.bits, ii.max, ii.min)


# Note: isdtype is a new function from the 2022.12 array API specification.
def isdtype(
dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]]
) -> bool:
"""
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.

See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
for more details
"""
if isinstance(kind, tuple):
# Disallow nested tuples
if any(isinstance(k, tuple) for k in kind):
raise TypeError("'kind' must be a dtype, str, or tuple of dtypes and strs")
return any(isdtype(dtype, k) for k in kind)
elif isinstance(kind, str):
if kind == 'bool':
return dtype in _boolean_dtypes
elif kind == 'signed integer':
return dtype in _signed_integer_dtypes
elif kind == 'unsigned integer':
return dtype in _unsigned_integer_dtypes
elif kind == 'integral':
return dtype in _integer_dtypes
elif kind == 'real floating':
return dtype in _real_floating_dtypes
elif kind == 'complex floating':
return dtype in _complex_floating_dtypes
elif kind == 'numeric':
return dtype in _numeric_dtypes
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
elif kind in _all_dtypes:
return dtype == kind
else:
raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}")

def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
"""
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
Expand Down
2 changes: 2 additions & 0 deletions numpy/array_api/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
_floating_dtypes = (float32, float64, complex64, complex128)
_complex_floating_dtypes = (complex64, complex128)
_integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64)
_signed_integer_dtypes = (int8, int16, int32, int64)
_unsigned_integer_dtypes = (uint8, uint16, uint32, uint64)
_integer_or_boolean_dtypes = (
bool,
int8,
Expand Down
14 changes: 13 additions & 1 deletion numpy/array_api/tests/test_data_type_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest

from numpy.testing import assert_raises
from numpy import array_api as xp

import numpy as np

@pytest.mark.parametrize(
"from_, to, expected",
Expand All @@ -17,3 +18,14 @@ def test_can_cast(from_, to, expected):
can_cast() returns correct result
"""
assert xp.can_cast(from_, to) == expected

def test_isdtype_strictness():
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, 64))
assert_raises(ValueError, lambda: xp.isdtype(xp.float64, 'f8'))

assert_raises(TypeError, lambda: xp.isdtype(xp.float64, (('integral',),)))
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, np.object_))

# TODO: These will require https://github.com/numpy/numpy/issues/23883
# assert_raises(TypeError, lambda: xp.isdtype(xp.float64, None))
# assert_raises(TypeError, lambda: xp.isdtype(xp.float64, np.float64))