From a5025d3b4312e5275b3cf091c2e61e7cc614d94e Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 29 Mar 2024 17:05:14 -0600 Subject: [PATCH 001/252] Add basic documentation, based on the README The documentation configuration is based on the configuration from array-api-compat. Fixes #27 --- .github/workflows/docs-build.yml | 22 ++++ .github/workflows/docs-deploy.yml | 30 +++++ docs/Makefile | 23 ++++ docs/_static/custom.css | 12 ++ docs/_static/favicon.png | Bin 0 -> 5152 bytes docs/conf.py | 84 ++++++++++++ docs/index.md | 206 ++++++++++++++++++++++++++++++ docs/make.bat | 35 +++++ docs/requirements.txt | 6 + 9 files changed, 418 insertions(+) create mode 100644 .github/workflows/docs-build.yml create mode 100644 .github/workflows/docs-deploy.yml create mode 100644 docs/Makefile create mode 100644 docs/_static/custom.css create mode 100644 docs/_static/favicon.png create mode 100644 docs/conf.py create mode 100644 docs/index.md create mode 100644 docs/make.bat create mode 100644 docs/requirements.txt diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml new file mode 100644 index 0000000..04c3aa6 --- /dev/null +++ b/.github/workflows/docs-build.yml @@ -0,0 +1,22 @@ +name: Docs Build + +on: [push, pull_request] + +jobs: + docs-build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + - name: Install Dependencies + run: | + python -m pip install -r docs/requirements.txt + - name: Build Docs + run: | + cd docs + make html + - name: Upload Artifact + uses: actions/upload-artifact@v4 + with: + name: docs-build + path: docs/_build/html diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml new file mode 100644 index 0000000..7560028 --- /dev/null +++ b/.github/workflows/docs-deploy.yml @@ -0,0 +1,30 @@ +name: Docs Deploy + +on: + push: + branches: + - main + +jobs: + docs-deploy: + runs-on: ubuntu-latest + environment: + name: docs-deploy + steps: + - uses: actions/checkout@v4 + - name: Download Artifact + uses: dawidd6/action-download-artifact@v2 + with: + workflow: docs-build.yml + name: docs-build + path: docs/_build/html + + # Note, the gh-pages deployment requires setting up a SSH deploy key. + # See + # https://github.com/JamesIves/github-pages-deploy-action/tree/dev#using-an-ssh-deploy-key- + - name: Deploy + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: docs/_build/html + ssh-key: ${{ secrets.DEPLOY_KEY }} + force: no diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..11356c4 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,23 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +livehtml: + sphinx-autobuild --open-browser --watch .. --port 0 -b html $(SOURCEDIR) $(ALLSPHINXOPTS) $(BUILDDIR)/html diff --git a/docs/_static/custom.css b/docs/_static/custom.css new file mode 100644 index 0000000..bac0498 --- /dev/null +++ b/docs/_static/custom.css @@ -0,0 +1,12 @@ +/* Makes the text look better on Mac retina displays (the Furo CSS disables*/ +/* subpixel antialiasing). */ +body { + -webkit-font-smoothing: auto; + -moz-osx-font-smoothing: auto; +} + +/* Disable the fancy scrolling behavior when jumping to headers (this is too + slow for long pages) */ +html { + scroll-behavior: auto; +} diff --git a/docs/_static/favicon.png b/docs/_static/favicon.png new file mode 100644 index 0000000000000000000000000000000000000000..49b7d9d6fa9e82e24c907c324f53df2dbf11a222 GIT binary patch literal 5152 zcmbtY_fr#quSW$bLueVLBCBOjt$?)1EVNA72(py2q_X!EW$#g;>?Io!s0`W277$Rj z4A~$BL}bY3`R=_x;QQ|7F8SQ$l1nb*ha^fzOXU_dD>WGz*)26yWxapc{g1l$>)~zbW<#cE?ToeIQggJlv(dA$wDxuFw~-+uW9(E@Ry6RL+0Ld6 zVD14&Xj4bllBofU$OT6q0z!p?4Y;{lrMN@20Ec@z$*g{H3O|91dkTqQlx`FCsK!sj z*F2$#U!lGmUs|tm{|rfPZT7!f*!?K}qVfB@`2NoD_OCAt*Z!sQA5ygyo2I9Pa768T zE#fYB+4s_Bfgen9b%pc3ja@_wYUn}tpAP)++X}Jw2<;&aruk^rjuTBBoHa%2#A>ID z_^aQS&H{^^yo1;sAhyv!9Ugt<()Q!uXR{d8AQGKrul~lzf6KBc%1jzV(B~1Xbkg5| z?etPEv%kPlTJfIY@;#?iB^$`e3&e0HSLhD)b1s%dPL}wsh07oyHxg@}x*~JQ3$lp! zv1aQn^i)N;coC->Lv8;9S3+_%xH z0%+Qj2;}e&wiSIKtkHOTD%q{!(Q=)GnW|R*k#{!R3L-A=#S$1kbLvdg0D;e#S3Q7z z4RRZmYT%`{HzT*=_u7$d?oeet<7s)(5^;glW)>Db0@G|fI@YC$!7Fx^fIZvG_kwnI zUQ#3j8t3Gf=Hjffq$#WdB37D1rkY#>oKcZdcPmWHblmI;o<)Pl2eIcJ6#Pcf@M2~q z4G=@`-*m~gKxF+puI{bsPe!9rt4Vcrbt@IJTJZ_UTLpt33!UKvX-$L`H3)oMtZQVJ zP%&!cGNl=sx^kz3Xg>Ht7Ew>n)x7|3fBM8vepbXn7^R)t>bQU5>cs$e8tu6Q>qU?d zdlgW0+JZKRG2P;{7ng6-`N0!hcJ~o^-(l1nMq*~_X@1c z{XD=EWc{Wm#58J)!D4H~H4Ml!EycMOXjDnP*6_*9!@MIad-cKcE$OR16fb5&Ey1%f z_jTY{DK+8m{1K1$`hftdvq8OP6Vf*kVdeMrWmwv53HzGAGdYW~*!@v)kX^}h-rmLq zQLIP$xiZU{NMndXaAVsKqW{xsbjCx{`Jm%|JxX#?amF?H?8fTxaAK{<%3a?lAGh0YUuO7?YoQ z#><#~|LDiWMh%hMlDkDL9=tNZ7eZD&z7XukC z*KftDvTreSe`HsMZf?cEYgC*(M+Hc_H(-hVVUZ=z=zQ@cMa7UUw;!oJzpEMHm0!PU zP>%QK^~DZSU7j8M;t&G&xNB(ssRj&tadmRwH;|$#SDJq?pEh4Sm}L%>p%qTn$f;N_ z=907~-VIwAvUv&DQ}fG+H=F))!n4)Bqrv-UT$ku`&J(w3wMcNSl0G6>t*yF5X_*sG z#5T7!T~P4w8c6!KLTp6&edAXE;4(AQUZQ$ZE`jdnXJ1Zm*Xviel@rw{`SiD3+udF! zX_-yVi;+>{osrB@uFo>K8nsQvhJnCKw zI&o7(y9N;M?}ObgR*8~%Vs9Db8+=9)^9>+5Y76_!SZhE2@*IPIau#pq^L$)M%>sJf zoV<7dII%}uXqC0#*Yi%=KJ18zV3fvn9--5;*!e#*LzC&Uryi4EiQ&CAP9G@cqPBCY zW5;4^cMZ5UUU1!1R@DUwf$?ptlm_3Y40BNq?4r9H25%+GibQ=Qk8-m5PX0=8=`=q} z8Roaw%2=r4*~k_YA&HMKT6fN`dE@OJq_yTsm5+C_>h(XzZi$r9`+^DdD@=J238JuZ zw(UG4sI_f9Y2@l;XWR&R5#^Pgh6-6{|0>~Qvv1lB*f!0uex+t?@MQd_atznzJ(7%? zG6GFGj1%rm)53cd2%%R)5Me(%Rzi))y8wHrSRDe2V``)}2|{5^tL^Fa8saR}DFRl^ z=%n=I54H@`(JM_j%PP+9J1XU(&#Bt+;9isCVjgb{lrc=~dOMKUdsP5f&`s@!b*h~n zYBm(la*8@{C$bod@{J+D#gqw`zSeorFjawGwbQ-QNRWaO%XzVV_26J_G280*R^g-o zn2<=h1g$Dpr^BBd823c{*hv0N^t}%ug}Hi%%OCW@(LSw9WZ~fD3G8tV#148>*@qjX z$*U}5p8|=npIwqH*9jMG%D1;P5aF>0QUClN>+THZ&diRw7ty3u6Is1$kXAf)Dg`&0J3WkRm2&ViD%>80L{{>C5cDtLFIe$}m&NtoZ(ft#8xc-+`^+ zYV9|scV&}D>AOsz&ZctL8>Bd(x`F_!+%@(^YEr#7*qV|m3 z&lbVZ3_tv75!)VJz185TT&n+gpaPLv03GJ+O<(SYVI1kITMEXLGxdZG7#A~bOrk-W zt+wpz0lbfPCbRaBni#w0ZDJVmr$j(YE$cAa7rc+=ht0Gm5O*z^K+D7n$Ej|-T%1Qe z{nyi{kISilJ0{jAos~9t)DMoXR=PFUXAkwbcvC^seOt6&(ay@KJdDRRUj)^O^xE^| ze&Xt#c*qvB{vI*X-TupKHmtDO?8E6RUmI4k+xE7Gq0vkRHr_Cmz=NQmt#w!eo@L*8tzf;~RESc1Sa6~%6Q0SVNXY!~PpwcOr zqEyYMEcu0j&HmG(p);BeN>QvaAzLI|=8zBtq;5xr0`S)iZrqp>1O?O2U_$v)cGStw z7?($|5@q)zr{tF3g=Gb%j)q?0gu@GzN^$COzYd5btIrLzZvl08LpI)T1W;bJ&uINw zJ8OJ5lgekvS~~1@aN%JcftX!UbhmtiqnD~O$=PR}Yp;%kyB-QYs&7}jC4806#NFKcn|;?wrV5TQ-{8pUl#K&6EI+M3rT{Qy|Qx=sN1GpbF`ElW3wY zUsv(PWS9dC6@$w)Mf)VY%BI=F!ErA7>U^8ons27&9?TPubj4pjShHg0F8@umQ=w2r zLbEivFld}%E37{e88BD9N!ME5Q$iu31F5Kt^cT9cXb1%dW74Vn&_G@#3 zOEJ8El8gV9PVeOU#%8L{Mf=#?n1(IbUULUWw-&fN05e#&>VFOcSqDdxv9?)@gjs0R z-!{buNuQ4h6>rqC2F_!g79=->F}7n&!5U2vd>N9(F-5qSI-_1pFBPWxrqq1st|3x+ ze$euRaB$(fC!rHrRYi4YSp5D-$=qKkt8X-f^V{#q`-#PMf*E>LJ=2ObDC@*gy^Ot( zoxf#gCHTk1bajRH25-;;I235|>y*Y+>2%fiOc-a4)(L&FMCEGb%u<($TGcP{=zB?! z^wItb<~#O32$CcRN%z*Hvqf$eY`Jo_T+*rMc5s9OGX{@6gw#SZCYM^E zXysI?bjbZ-_iHITNM7A``Frp0#@zcB0JBE)*>lC6BzzI_Nqc;5y>C9+*Q`lUVB=Gv zEJJjxP8C*?d=}y_zcNt-@nhEUq)&s%%9lST3}Dv%E*A3vkdL?mlKZ9tc_0-?`)dN8 zPbXkKmrqFK&wBf6tZ_GjuEn1sAn=hRH2>#j$jKQ|vxY3Xo4V6}Aw;a9qbQ zEnV+7|44O}GV5e|C=lEh7HLe&9Vpg(J-z zk_M`uzq5QQfA0u5N&#g!E-D9dg^Mazn-NRio0blI(&Gs4WwR<@>2YyM6i;S3JT~Zl z{ssH&i_bt*bxQFSmv`Kx>-pWZwXa_cE!Uls8h^65|J(DE8AZ8U8Hp{bklP{ZSphJI zg7mF5WT#2|U*8(`x5$ya*Bvkn*^^)DfxX!A5QrLKDGczsV`IoKZQV)FnA+*>Vf44e zwu-w|q2qZEfh%>ymWaBH>kr6XWk%93Ty_fLWbJC0R(*b11yr4ae;krNFiq&rUg?sb zoZ^1(OL=ZM)W>wyS7~i*Kl*p`dG_vC{-sDi^dxof$D$EA3Zbiz!D~=juJD2tg$WKP zJzw%C#YFNyaZQbxYbzxGok!$nU{p2j?%W`XTq$zHY^hwRN3`F;vpQO^Qk+?A z=t}_$v9Cp$;)vWn;wLo24ta4GxgIl$5O(pxRW8!4k!3}OCT7VR1?t>XpN0>o?lC-< zeMttjP(0g%W`!H~{OrXhZ_LH@s5-iUa$Nmpg*qWyuGay3;aHI);nGS7a!zhZMm#2| z_ZduPT_U(L?CQ&jh+L(pfEggkXXZ4QfMFKBt)prV|Iqy4!-Jwxl3D0k7G~@a0D30D z*(b=v7pf{Z6Z?jGBBlq@eew)k5F*x79qE#ZHqJW*5H4;4VLjc6)Kam8TWnn?zyU%} zj5~dF+|aYJ3pbAz6}PFM3A3N&VRzIOL&DvDr2Ocgd$XU_QfP-{ zlVE~ojJ@*^cLQyR`?OStc7g7TQ_*j0^+{W)7xJYTlm@XYZV@K^TF fj4S;lm+S^vx7AGCc+NcFza7=kTFRBEXTkpkp8ADl literal 0 HcmV?d00001 diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..d6550e8 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,84 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +import sys +import os +sys.path.insert(0, os.path.abspath('..')) + +project = 'array-api-strict' +copyright = '2024, Consortium for Python Data API Standards' +author = 'Consortium for Python Data API Standards' + +import array_api_strict +release = array_api_strict.__version__ + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + 'myst_parser', + # 'sphinx.ext.autodoc', + # 'sphinx.ext.napoleon', + # 'sphinx.ext.intersphinx', + 'sphinx_copybutton', +] + +intersphinx_mapping = { +} +# Require :external: to reference intersphinx. +intersphinx_disabled_reftypes = ['*'] + +templates_path = ['_templates'] +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +myst_enable_extensions = ["dollarmath", "linkify"] + +napoleon_use_rtype = False +napoleon_use_param = False + +# Make sphinx give errors for bad cross-references +nitpicky = True +# autodoc wants to make cross-references for every type hint. But a lot of +# them don't actually refer to anything that we have a document for. +nitpick_ignore = [ + ("py:class", "Array"), + ("py:class", "Device"), +] + +# Lets us use single backticks for code in RST +default_role = 'code' + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = 'furo' +html_static_path = ['_static'] + +html_css_files = ['custom.css'] + +html_theme_options = { + # See https://pradyunsg.me/furo/customisation/footer/ + "footer_icons": [ + { + "name": "GitHub", + "url": "https://github.com/data-apis/array-api-strict", + "html": """ + + + + """, + "class": "", + }, + ], +} + +# Logo + +html_favicon = "_static/favicon.png" + +# html_logo = "_static/logo.svg" diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..7c222e3 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,206 @@ +# array-api-strict + +`array_api_strict` is a strict, minimal implementation of the [Python array +API](https://data-apis.org/array-api/latest/) + +The purpose of array-api-strict is to provide an implementation of the array +API for consuming libraries to test against so they can be completely sure +their usage of the array API is portable. + +It is *not* intended to be used by end-users. End-users of the array API +should just use their favorite array library (NumPy, CuPy, PyTorch, etc.) as +usual. It is also not intended to be used as a dependency by consuming +libraries. Consuming library code should use the +[array-api-compat](https://github.com/data-apis/array-api-compat) package to +support the array API. Rather, it is intended to be used in the test suites of +consuming libraries to test their array API usage. + +array-api-strict currently supports the 2022.12 version of the standard. +2023.12 support is planned and is tracked by [this +issue](https://github.com/data-apis/array-api-strict/issues/25). + +## Install + +`array-api-strict` is available on both +[PyPI](https://pypi.org/project/array-api-strict/) + +``` +python -m pip install array-api-strict +``` + +and [Conda-forge](https://anaconda.org/conda-forge/array-api-strict) + +``` +conda install --channel conda-forge array-api-strict +``` + +array-api-strict supports NumPy 1.26 and (the upcoming) NumPy 2.0. + +## Rationale + +The array API has many functions and behaviors that are required to be +implemented by conforming libraries, but it does not, in most cases, disallow +implementing additional functions, keyword arguments, and behaviors that +aren't explicitly required by the standard. + +However, this poses a problem for consumers of the array API, as they may +accidentally use a function or rely on a behavior which just happens to be +implemented in every array library they test against (e.g., NumPy and +PyTorch), but isn't required by the standard and may not be included in other +libraries. + +array-api-strict solves this problem by providing a strict, minimal +implementation of the array API standard. Only those functions and behaviors +that are explicitly *required* by the standard are implemented. For example, +most NumPy functions accept Python scalars as inputs: + +```py +>>> import numpy as np +>>> np.sin(0.0) +0.0 +``` + +However, the standard only specifies function inputs on `Array` objects. And +indeed, some libraries, such as PyTorch, do not allow this: + +```py +>>> import torch +>>> torch.sin(0.0) +Traceback (most recent call last): + File "", line 1, in +TypeError: sin(): argument 'input' (position 1) must be Tensor, not float +``` + +In array-api-strict, this is also an error: + +```py +>>> import array_api_strict as xp +>>> xp.sin(0.0) +Traceback (most recent call last): +... +AttributeError: 'float' object has no attribute 'dtype' +``` + +Here is an (incomplete) list of the sorts of ways that array-api-strict is +strict/minimal: + +- Only those functions and methods that are [defined in the + standard](https://data-apis.org/array-api/latest/API_specification/index.html) + are included. + +- In those functions, only the keyword-arguments that are defined by the + standard are included. All signatures in array-api-strict use + [positional-only + arguments](https://data-apis.org/array-api/latest/API_specification/function_and_method_signatures.html#function-and-method-signatures). + As noted above, only `array_api_strict` array objects are accepted by + functions, except in the places where the standard allows Python scalars + (i.e., functions do not automatically call `asarray` on their inputs). + +- Only those [dtypes that are defined in the + standard](https://data-apis.org/array-api/latest/API_specification/data_types.html) + are included. + +- All functions and methods reject inputs if the standard does not *require* + the input dtype(s) to be supported. This is one of the most restrictive + aspects of the library. For example, in NumPy, most transcendental functions + like `sin` will accept integer array inputs, but the [standard only requires + them to accept floating-point + inputs](https://data-apis.org/array-api/latest/API_specification/generated/array_api.sin.html#array_api.sin), + so in array-api-strict, `sin(integer_array)` will raise an exception. + +- The + [indexing](https://data-apis.org/array-api/latest/API_specification/indexing.html) + semantics required by the standard are limited compared to those implemented + by NumPy (e.g., out-of-bounds slices are not supported, integer array + indexing is not supported, only a single boolean array index is supported). + +- There are no distinct "scalar" objects as in NumPy. There are only 0-D + arrays. + +- Dtype objects are just empty objects that only implement [equality + comparison](https://data-apis.org/array-api/latest/API_specification/generated/array_api.data_types.__eq__.html). + The way to access dtype objects in the standard is by name, like + `xp.float32`. + +- The array object type itself is private and should not be accessed. + Subclassing or otherwise trying to directly initialize this object is not + supported. Arrays should be created with one of the [array creation + functions](https://data-apis.org/array-api/latest/API_specification/creation_functions.html) + such as `asarray`. + +## Caveats + +array-api-strict is a thin pure Python wrapper around NumPy. NumPy 2.0 fully +supports the array API but NumPy 1.26 does not, so many behaviors are wrapped +in NumPy 1.26 to provide array API compatible behavior. Although it is based +on NumPy, mixing NumPy arrays with array-api-strict arrays is not supported. +This should generally raise an error, as it indicates a potential portability +issue, but this hasn't necessarily been tested thoroughly. + +1. array-api-strict is validated against the [array API test + suite](https://github.com/data-apis/array-api-tests). However, there may be + a few minor instances where NumPy deviates from the standard in a way that + is inconvenient to workaround in array-api-strict, since it aims to remain + pure Python. You can see the full list of tests that are known to fail in + the [xfails + file](https://github.com/data-apis/array-api-strict/blob/main/array-api-tests-xfails.txt). + + The most notable of these is that in NumPy 1.26, the `copy=False` flag is + not implemented for `asarray` and therefore `array_api_strict` raises + `NotImplementedError` in that case. + +2. Since NumPy is a CPU-only library, the [device + support](https://data-apis.org/array-api/latest/design_topics/device_support.html) + in array-api-strict is superficial only. `x.device` is always a (private) + `CPU_DEVICE` object, and `device` keywords to creation functions only + accept either this object or `None`. A future version of array-api-strict + [may add support for a CuPy + backend](https://github.com/data-apis/array-api-strict/issues/5) so that + more significant device support can be tested. + +3. Although only array types are expected in array-api-strict functions, + currently most functions do not do extensive type checking on their inputs, + so a sufficiently duck-typed object may pass through silently (or at best, + you may get `AttributeError` instead of `TypeError`). However, all type + signatures have type annotations (based on those from the standard), so + this deviation may be tested with type checking. This [behavior may improve + in the future](https://github.com/data-apis/array-api-strict/issues/6). + +4. There are some behaviors in the standard that are not required to be + implemented by libraries that cannot support [data dependent + shapes](https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html). + This includes [the `unique_*` + functions](https://data-apis.org/array-api/latest/API_specification/set_functions.html), + [boolean array + indexing](https://data-apis.org/array-api/latest/API_specification/indexing.html#boolean-array-indexing), + and the + [`nonzero`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.nonzero.html) + function. array-api-strict currently implements all of these. In the + future, [there may be a way to disable them](https://github.com/data-apis/array-api-strict/issues/7). + +5. array-api-strict currently only supports the latest version of the array + API standard. [This may change in the future depending on + need](https://github.com/data-apis/array-api-strict/issues/8). + +## Usage + +TODO: Add a sample CI script here. + +## Relationship to `numpy.array_api` + +Previously this implementation was available as `numpy.array_api`, but it was +moved to a separate package for NumPy 2.0. + +Note that the history of this repo prior to commit +fbefd42e4d11e9be20e0a4785f2619fc1aef1e7c was generated automatically +from the numpy git history, using the following +[git-filter-repo](https://github.com/newren/git-filter-repo) command: + +``` +git_filter_repo.py --path numpy/array_api/ --path-rename numpy/array_api:array_api_strict --replace-text <(echo -e "numpy.array_api==>array_api_strict\nfrom ..core==>from numpy.core\nfrom .._core==>from numpy._core\nfrom ..linalg==>from numpy.linalg\nfrom numpy import array_api==>import array_api_strict") --commit-callback 'commit.message = commit.message.rstrip() + b"\n\nOriginal NumPy Commit: " + commit.original_id' +``` + +```{toctree} +:titlesonly: +:hidden: +``` diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..32bb245 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..dbec774 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,6 @@ +furo +linkify-it-py +myst-parser +sphinx +sphinx-copybutton +sphinx-autobuild From 166982c010ce7d4753a9438235bfccf1681f019d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 29 Mar 2024 17:08:32 -0600 Subject: [PATCH 002/252] Move the changelog into the docs --- CHANGELOG.md | 46 +--------------------------------------------- docs/changelog.md | 45 +++++++++++++++++++++++++++++++++++++++++++++ docs/index.md | 2 ++ 3 files changed, 48 insertions(+), 45 deletions(-) mode change 100644 => 120000 CHANGELOG.md create mode 100644 docs/changelog.md diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index e380bf3..0000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,45 +0,0 @@ -# array-api-strict Changelog - -## 1.1.1 (2024-04-29) - -- Fix the `api_version` argument to `__array_namespace__` to accept - `'2021.12'` or `'2022.12'`. - -## 1.1 (2024-04-08) - -- Fix the `copy` flag in `__array__` for NumPy 2.0. - -- Add full `copy=False` support to `asarray()`. This is emulated in NumPy 1.26 by creating - the array and seeing if it is copied. For NumPy 2.0, the new native - `copy=False` flag is used. - -- Add broadcasting support to `cross`. - -## 1.0 (2024-01-24) - -This is the first release of `array_api_strict`. It is extracted from -`numpy.array_api`, which was included as an experimental submodule in NumPy -versions prior to 2.0. Note that the commit history in this repository is -extracted from the git history of numpy/array_api/ (see the [README](README.md)). - -Additionally, the following changes are new to `array_api_strict` from -`numpy.array_api` in NumPy 1.26 (the last NumPy feature release to include -`numpy.array_api`): - -- ``array_api_strict`` was made more portable. In particular: - - - ``array_api_strict`` no longer uses ``"cpu"`` as its "device", but rather a - separate ``CPU_DEVICE`` object (which is not accessible in the namespace). - This is because "cpu" is not part of the array API standard. - - - ``array_api_strict`` now uses separate wrapped objects for dtypes. - Previously it reused the ``numpy`` dtype objects. This makes it clear - which behaviors on dtypes are part of the array API standard (effectively, - the standard only requires ``==`` on dtype objects). - -- ``numpy.array_api.nonzero`` now errors on zero-dimensional arrays, as - required by the array API standard. - -- Support for the optional [fft - extension](https://data-apis.org/array-api/latest/extensions/fourier_transform_functions.html) - was added. diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 120000 index 0000000..1bed66b --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1 @@ +docs/changelog.md \ No newline at end of file diff --git a/docs/changelog.md b/docs/changelog.md new file mode 100644 index 0000000..8f1c203 --- /dev/null +++ b/docs/changelog.md @@ -0,0 +1,45 @@ +# Changelog + +## 1.1.1 (2024-04-29) + +- Fix the `api_version` argument to `__array_namespace__` to accept + `'2021.12'` or `'2022.12'`. + +## 1.1 (2024-04-08) + +- Fix the `copy` flag in `__array__` for NumPy 2.0. + +- Add full `copy=False` support to `asarray()`. This is emulated in NumPy 1.26 by creating + the array and seeing if it is copied. For NumPy 2.0, the new native + `copy=False` flag is used. + +- Add broadcasting support to `cross`. + +## 1.0 (2024-01-24) + +This is the first release of `array_api_strict`. It is extracted from +`numpy.array_api`, which was included as an experimental submodule in NumPy +versions prior to 2.0. Note that the commit history in this repository is +extracted from the git history of numpy/array_api/ (see the [README](README.md)). + +Additionally, the following changes are new to `array_api_strict` from +`numpy.array_api` in NumPy 1.26 (the last NumPy feature release to include +`numpy.array_api`): + +- ``array_api_strict`` was made more portable. In particular: + + - ``array_api_strict`` no longer uses ``"cpu"`` as its "device", but rather a + separate ``CPU_DEVICE`` object (which is not accessible in the namespace). + This is because "cpu" is not part of the array API standard. + + - ``array_api_strict`` now uses separate wrapped objects for dtypes. + Previously it reused the ``numpy`` dtype objects. This makes it clear + which behaviors on dtypes are part of the array API standard (effectively, + the standard only requires ``==`` on dtype objects). + +- ``numpy.array_api.nonzero`` now errors on zero-dimensional arrays, as + required by the array API standard. + +- Support for the optional [fft + extension](https://data-apis.org/array-api/latest/extensions/fourier_transform_functions.html) + was added. diff --git a/docs/index.md b/docs/index.md index 7c222e3..c615634 100644 --- a/docs/index.md +++ b/docs/index.md @@ -203,4 +203,6 @@ git_filter_repo.py --path numpy/array_api/ --path-rename numpy/array_api:array_a ```{toctree} :titlesonly: :hidden: + +changelog.md ``` From 00d7ac5f93da0861d9dd35a1ec12a5d3683381a5 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 29 Mar 2024 17:17:20 -0600 Subject: [PATCH 003/252] Clear most of the text in the README --- README.md | 183 +----------------------------------------------------- 1 file changed, 2 insertions(+), 181 deletions(-) diff --git a/README.md b/README.md index d4a3224..b392ff3 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # array-api-strict `array_api_strict` is a strict, minimal implementation of the [Python array -API](https://data-apis.org/array-api/latest/) +API](https://data-apis.org/array-api/latest/). The purpose of array-api-strict is to provide an implementation of the array API for consuming libraries to test against so they can be completely sure @@ -19,183 +19,4 @@ array-api-strict currently supports the 2022.12 version of the standard. 2023.12 support is planned and is tracked by [this issue](https://github.com/data-apis/array-api-strict/issues/25). -## Install - -`array-api-strict` is available on both -[PyPI](https://pypi.org/project/array-api-strict/) - -``` -python -m pip install array-api-strict -``` - -and [Conda-forge](https://anaconda.org/conda-forge/array-api-strict) - -``` -conda install --channel conda-forge array-api-strict -``` - -array-api-strict supports NumPy 1.26 and (the upcoming) NumPy 2.0. - -## Rationale - -The array API has many functions and behaviors that are required to be -implemented by conforming libraries, but it does not, in most cases, disallow -implementing additional functions, keyword arguments, and behaviors that -aren't explicitly required by the standard. - -However, this poses a problem for consumers of the array API, as they may -accidentally use a function or rely on a behavior which just happens to be -implemented in every array library they test against (e.g., NumPy and -PyTorch), but isn't required by the standard and may not be included in other -libraries. - -array-api-strict solves this problem by providing a strict, minimal -implementation of the array API standard. Only those functions and behaviors -that are explicitly *required* by the standard are implemented. For example, -most NumPy functions accept Python scalars as inputs: - -```py ->>> import numpy as np ->>> np.sin(0.0) -0.0 -``` - -However, the standard only specifies function inputs on `Array` objects. And -indeed, some libraries, such as PyTorch, do not allow this: - -```py ->>> import torch ->>> torch.sin(0.0) -Traceback (most recent call last): - File "", line 1, in -TypeError: sin(): argument 'input' (position 1) must be Tensor, not float -``` - -In array-api-strict, this is also an error: - -```py ->>> import array_api_strict as xp ->>> xp.sin(0.0) -Traceback (most recent call last): -... -AttributeError: 'float' object has no attribute 'dtype' -``` - -Here is an (incomplete) list of the sorts of ways that array-api-strict is -strict/minimal: - -- Only those functions and methods that are [defined in the - standard](https://data-apis.org/array-api/latest/API_specification/index.html) - are included. - -- In those functions, only the keyword-arguments that are defined by the - standard are included. All signatures in array-api-strict use - [positional-only - arguments](https://data-apis.org/array-api/latest/API_specification/function_and_method_signatures.html#function-and-method-signatures). - As noted above, only `array_api_strict` array objects are accepted by - functions, except in the places where the standard allows Python scalars - (i.e., functions do not automatically call `asarray` on their inputs). - -- Only those [dtypes that are defined in the - standard](https://data-apis.org/array-api/latest/API_specification/data_types.html) - are included. - -- All functions and methods reject inputs if the standard does not *require* - the input dtype(s) to be supported. This is one of the most restrictive - aspects of the library. For example, in NumPy, most transcendental functions - like `sin` will accept integer array inputs, but the [standard only requires - them to accept floating-point - inputs](https://data-apis.org/array-api/latest/API_specification/generated/array_api.sin.html#array_api.sin), - so in array-api-strict, `sin(integer_array)` will raise an exception. - -- The - [indexing](https://data-apis.org/array-api/latest/API_specification/indexing.html) - semantics required by the standard are limited compared to those implemented - by NumPy (e.g., out-of-bounds slices are not supported, integer array - indexing is not supported, only a single boolean array index is supported). - -- There are no distinct "scalar" objects as in NumPy. There are only 0-D - arrays. - -- Dtype objects are just empty objects that only implement [equality - comparison](https://data-apis.org/array-api/latest/API_specification/generated/array_api.data_types.__eq__.html). - The way to access dtype objects in the standard is by name, like - `xp.float32`. - -- The array object type itself is private and should not be accessed. - Subclassing or otherwise trying to directly initialize this object is not - supported. Arrays should be created with one of the [array creation - functions](https://data-apis.org/array-api/latest/API_specification/creation_functions.html) - such as `asarray`. - -## Caveats - -array-api-strict is a thin pure Python wrapper around NumPy. NumPy 2.0 fully -supports the array API but NumPy 1.26 does not, so many behaviors are wrapped -in NumPy 1.26 to provide array API compatible behavior. Although it is based -on NumPy, mixing NumPy arrays with array-api-strict arrays is not supported. -This should generally raise an error, as it indicates a potential portability -issue, but this hasn't necessarily been tested thoroughly. - -1. array-api-strict is validated against the [array API test - suite](https://github.com/data-apis/array-api-tests). However, there may be - a few minor instances where NumPy deviates from the standard in a way that - is inconvenient to workaround in array-api-strict, since it aims to remain - pure Python. You can see the full list of tests that are known to fail in - the [xfails - file](https://github.com/data-apis/array-api-strict/blob/main/array-api-tests-xfails.txt). - - The most notable of these is that in NumPy 1.26, the `copy=False` flag is - not implemented for `asarray` and therefore `array_api_strict` raises - `NotImplementedError` in that case. - -2. Since NumPy is a CPU-only library, the [device - support](https://data-apis.org/array-api/latest/design_topics/device_support.html) - in array-api-strict is superficial only. `x.device` is always a (private) - `CPU_DEVICE` object, and `device` keywords to creation functions only - accept either this object or `None`. A future version of array-api-strict - [may add support for a CuPy - backend](https://github.com/data-apis/array-api-strict/issues/5) so that - more significant device support can be tested. - -3. Although only array types are expected in array-api-strict functions, - currently most functions do not do extensive type checking on their inputs, - so a sufficiently duck-typed object may pass through silently (or at best, - you may get `AttributeError` instead of `TypeError`). However, all type - signatures have type annotations (based on those from the standard), so - this deviation may be tested with type checking. This [behavior may improve - in the future](https://github.com/data-apis/array-api-strict/issues/6). - -4. There are some behaviors in the standard that are not required to be - implemented by libraries that cannot support [data dependent - shapes](https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html). - This includes [the `unique_*` - functions](https://data-apis.org/array-api/latest/API_specification/set_functions.html), - [boolean array - indexing](https://data-apis.org/array-api/latest/API_specification/indexing.html#boolean-array-indexing), - and the - [`nonzero`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.nonzero.html) - function. array-api-strict currently implements all of these. In the - future, [there may be a way to disable them](https://github.com/data-apis/array-api-strict/issues/7). - -5. array-api-strict currently only supports the latest version of the array - API standard. [This may change in the future depending on - need](https://github.com/data-apis/array-api-strict/issues/8). - -## Usage - -TODO: Add a sample CI script here. - -## Relationship to `numpy.array_api` - -Previously this implementation was available as `numpy.array_api`, but it was -moved to a separate package for NumPy 2.0. - -Note that the history of this repo prior to commit -fbefd42e4d11e9be20e0a4785f2619fc1aef1e7c was generated automatically -from the numpy git history, using the following -[git-filter-repo](https://github.com/newren/git-filter-repo) command: - -``` -git_filter_repo.py --path numpy/array_api/ --path-rename numpy/array_api:array_api_strict --replace-text <(echo -e "numpy.array_api==>array_api_strict\nfrom ..core==>from numpy.core\nfrom .._core==>from numpy._core\nfrom ..linalg==>from numpy.linalg\nfrom numpy import array_api==>import array_api_strict") --commit-callback 'commit.message = commit.message.rstrip() + b"\n\nOriginal NumPy Commit: " + commit.original_id' -``` +See the documentation for more details https://data-apis.org/array-api-strict/ From 00e91856b5b8939136901421cc2c6eafcbb18bf3 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 29 Mar 2024 17:17:32 -0600 Subject: [PATCH 004/252] Some small updates to the docs --- docs/index.md | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/docs/index.md b/docs/index.md index c615634..8672cd6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -145,11 +145,12 @@ issue, but this hasn't necessarily been tested thoroughly. the [xfails file](https://github.com/data-apis/array-api-strict/blob/main/array-api-tests-xfails.txt). - The most notable of these is that in NumPy 1.26, the `copy=False` flag is - not implemented for `asarray` and therefore `array_api_strict` raises - `NotImplementedError` in that case. +2. array-api-strict is just a thin, pure Python wrapper around `numpy.ndarray` + and `numpy` functions. As such, the performance should be mostly comparable + to NumPy, but nonetheless, performance is not a primary concern for this + library, since it is only intended to be used for testing purposes. -2. Since NumPy is a CPU-only library, the [device +3. Since NumPy is a CPU-only library, the [device support](https://data-apis.org/array-api/latest/design_topics/device_support.html) in array-api-strict is superficial only. `x.device` is always a (private) `CPU_DEVICE` object, and `device` keywords to creation functions only @@ -158,7 +159,7 @@ issue, but this hasn't necessarily been tested thoroughly. backend](https://github.com/data-apis/array-api-strict/issues/5) so that more significant device support can be tested. -3. Although only array types are expected in array-api-strict functions, +4. Although only array types are expected in array-api-strict functions, currently most functions do not do extensive type checking on their inputs, so a sufficiently duck-typed object may pass through silently (or at best, you may get `AttributeError` instead of `TypeError`). However, all type @@ -166,7 +167,7 @@ issue, but this hasn't necessarily been tested thoroughly. this deviation may be tested with type checking. This [behavior may improve in the future](https://github.com/data-apis/array-api-strict/issues/6). -4. There are some behaviors in the standard that are not required to be +5. There are some behaviors in the standard that are not required to be implemented by libraries that cannot support [data dependent shapes](https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html). This includes [the `unique_*` @@ -178,9 +179,9 @@ issue, but this hasn't necessarily been tested thoroughly. function. array-api-strict currently implements all of these. In the future, [there may be a way to disable them](https://github.com/data-apis/array-api-strict/issues/7). -5. array-api-strict currently only supports the latest version of the array - API standard. [This may change in the future depending on - need](https://github.com/data-apis/array-api-strict/issues/8). +6. array-api-strict currently only supports the 2022.12 version of the array + API standard. [Support for 2023.12 is + planned](https://github.com/data-apis/array-api-strict/issues/25). ## Usage @@ -192,8 +193,8 @@ Previously this implementation was available as `numpy.array_api`, but it was moved to a separate package for NumPy 2.0. Note that the history of this repo prior to commit -fbefd42e4d11e9be20e0a4785f2619fc1aef1e7c was generated automatically -from the numpy git history, using the following +[fbefd42e4d11e9be20e0a4785f2619fc1aef1e7c](https://github.com/data-apis/array-api-strict/commit/fbefd42e4d11e9be20e0a4785f2619fc1aef1e7c) +was generated automatically from the numpy git history, using the following [git-filter-repo](https://github.com/newren/git-filter-repo) command: ``` From 20e1f37fe0a34ff0ec073794bd75461454587b91 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 29 Mar 2024 17:22:08 -0600 Subject: [PATCH 005/252] Remove the version number from the docs Since we deploy on every commit, the version number will always end up containing the git hash. --- docs/conf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d6550e8..c068b06 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,8 +14,8 @@ copyright = '2024, Consortium for Python Data API Standards' author = 'Consortium for Python Data API Standards' -import array_api_strict -release = array_api_strict.__version__ +# import array_api_strict +# release = array_api_strict.__version__ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration From afdc760ee04fe3200044fd515cf20e7b9e257e0c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 29 Mar 2024 17:23:33 -0600 Subject: [PATCH 006/252] Add numpy to the dev requirements files --- docs/requirements.txt | 3 ++- requirements-dev.txt | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index dbec774..464a843 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,7 @@ furo linkify-it-py myst-parser +numpy sphinx -sphinx-copybutton sphinx-autobuild +sphinx-copybutton diff --git a/requirements-dev.txt b/requirements-dev.txt index 9b9e5f9..137e973 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,2 +1,3 @@ pytest hypothesis +numpy From 90b5075dbd4f248122ce6358565eb1b80b2859eb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 29 Mar 2024 17:24:59 -0600 Subject: [PATCH 007/252] Small docs cleanups --- docs/index.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/index.md b/docs/index.md index 8672cd6..4f3f00b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -183,16 +183,12 @@ issue, but this hasn't necessarily been tested thoroughly. API standard. [Support for 2023.12 is planned](https://github.com/data-apis/array-api-strict/issues/25). -## Usage - -TODO: Add a sample CI script here. - ## Relationship to `numpy.array_api` Previously this implementation was available as `numpy.array_api`, but it was moved to a separate package for NumPy 2.0. -Note that the history of this repo prior to commit +The history of this repo prior to commit [fbefd42e4d11e9be20e0a4785f2619fc1aef1e7c](https://github.com/data-apis/array-api-strict/commit/fbefd42e4d11e9be20e0a4785f2619fc1aef1e7c) was generated automatically from the numpy git history, using the following [git-filter-repo](https://github.com/newren/git-filter-repo) command: From a97daf93f29df969fceadbff21965d5361b9229c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 29 Mar 2024 17:29:40 -0600 Subject: [PATCH 008/252] Update links to the array-api-compat docs --- README.md | 2 +- array_api_strict/__init__.py | 2 +- docs/index.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b392ff3..8172237 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ It is *not* intended to be used by end-users. End-users of the array API should just use their favorite array library (NumPy, CuPy, PyTorch, etc.) as usual. It is also not intended to be used as a dependency by consuming libraries. Consuming library code should use the -[array-api-compat](https://github.com/data-apis/array-api-compat) package to +[array-api-compat](https://data-apis.org/array-api-compat/) package to support the array API. Rather, it is intended to be used in the test suites of consuming libraries to test their array API usage. diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 326f55d..90c82c2 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -10,7 +10,7 @@ should just use their favorite array library (NumPy, CuPy, PyTorch, etc.) as usual. It is also not intended to be used as a dependency by consuming libraries. Consuming library code should use the -array-api-compat (https://github.com/data-apis/array-api-compat) package to +array-api-compat (https://data-apis.org/array-api-compat/) package to support the array API. Rather, it is intended to be used in the test suites of consuming libraries to test their array API usage. diff --git a/docs/index.md b/docs/index.md index 4f3f00b..307a9c2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -11,7 +11,7 @@ It is *not* intended to be used by end-users. End-users of the array API should just use their favorite array library (NumPy, CuPy, PyTorch, etc.) as usual. It is also not intended to be used as a dependency by consuming libraries. Consuming library code should use the -[array-api-compat](https://github.com/data-apis/array-api-compat) package to +[array-api-compat](https://data-apis.org/array-api-compat/) package to support the array API. Rather, it is intended to be used in the test suites of consuming libraries to test their array API usage. From 1680f2ced9826bb5f8f70584c3681bb4235799c3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Apr 2024 21:58:35 +0000 Subject: [PATCH 009/252] Bump the actions group with 1 update Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 2 to 3 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v2...v3) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 7560028..79cde63 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download Artifact - uses: dawidd6/action-download-artifact@v2 + uses: dawidd6/action-download-artifact@v3 with: workflow: docs-build.yml name: docs-build From 2cd4c3f58f769c73a4c905593d90d678517d41da Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 8 Apr 2024 16:03:25 -0600 Subject: [PATCH 010/252] Set up basic structure for array-api-strict flags Flags are global variables that set array-api-strict in a specific mode. Currently support flags change the support array API standard version, enable or disable data-dependent shapes, and enable or disable optional extensions. This commit only sets up the structure for setting and getting these flags. --- array_api_strict/__init__.py | 11 ++ array_api_strict/_flags.py | 256 +++++++++++++++++++++++++++++++++++ 2 files changed, 267 insertions(+) create mode 100644 array_api_strict/_flags.py diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 90c82c2..1fab8a2 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -284,6 +284,17 @@ __all__ += ["all", "any"] +# Helper functions that are not part of the standard + +from ._flags import ( + set_array_api_strict_flags, + get_array_api_strict_flags, + reset_array_api_strict_flags, + ArrayApiStrictFlags, +) + +__all__ += ['set_array_api_strict_flags', 'get_array_api_strict_flags', 'reset_array_api_strict_flags', 'ArrayApiStrictFlags'] + from . import _version __version__ = _version.get_versions()['version'] del _version diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py new file mode 100644 index 0000000..0d76903 --- /dev/null +++ b/array_api_strict/_flags.py @@ -0,0 +1,256 @@ +""" +This file defines flags for that allow array-api-strict to be used in +different "modes". These modes include + +- Changing to different supported versions of the standard. +- Enabling or disabling different optional behaviors (such as data-dependent + shapes). +- Enabling or disabling different optional extensions. + +Nothing in this file is part of the standard itself. A typical array API +library will only support one particular configuration of these flags. +""" + +import os + +supported_versions = [ + "2021.12", + "2022.12", +] + +STANDARD_VERSION = "2022.12" + +DATA_DEPENDENT_SHAPES = True + +all_extensions = [ + "linalg", + "fft", +] + +extension_versions = { + "linalg": "2021.12", + "fft": "2022.12", +} + +ENABLED_EXTENSIONS = [ + "linalg", + "fft", +] + +def set_array_api_strict_flags( + *, + standard_version=None, + data_dependent_shapes=None, + enabled_extensions=None, +): + """ + Set the array-api-strict flags to the specified values. + + Flags are global variables that enable or disable array-api-strict + behaviors. + + .. note:: + + This function is **not** part of the array API standard. It only exists + in array-api-strict. + + - `standard_version`: The version of the standard to use. Supported + versions are: ``{supported_versions}``. The default version number is + ``{default_version!r}``. + + - `data_dependent_shapes`: Whether data-dependent shapes are enabled in + array-api-strict. This flag is enabled by default. Array libraries that + use computation graphs may not be able to support functions whose output + shapes depend on the input data. + + This flag is enabled by default. Array libraries that use computation graphs may not be able to support + functions whose output shapes depend on the input data. + + The functions that make use of data-dependent shapes, and are therefore + disabled by setting this flag to False are + + - `unique_all`, `unique_counts`, `unique_inverse`, and `unique_values`. + - `nonzero` + - Boolean array indexing + - `repeat` when the `repeats` argument is an array (requires 2023.12 + version of the standard) + + See + https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html + for more details. + + - `enabled_extensions`: A list of extensions that are enabled in + array-api-strict. The default is ``{default_extensions}``. Note that + some extensions require a minimum version of the standard. + + The default values of the flags can also be changed by setting environment + variables: + + - ``ARRAY_API_STRICT_STANDARD_VERSION``: A string representing the version number. + - ``ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES``: "True" or "False". + - ``ARRAY_API_STRICT_ENABLED_EXTENSIONS``: A comma separated list of + extensions to enable. + + Examples + -------- + + >>> from array_api_strict import set_array_api_strict_flags + >>> # Set the standard version to 2021.12 + >>> set_array_api_strict_flags(standard_version="2021.12") + >>> # Disable data-dependent shapes + >>> set_array_api_strict_flags(data_dependent_shapes=False) + >>> # Enable only the linalg extension (disable the fft extension) + >>> set_array_api_strict_flags(enabled_extensions=["linalg"]) + + See Also + -------- + + get_array_api_strict_flags + reset_array_api_strict_flags + ArrayApiStrictFlags: A context manager to temporarily set the flags. + + """ + global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS + + if standard_version is not None: + if standard_version not in supported_versions: + raise ValueError(f"Unsupported standard version {standard_version}") + STANDARD_VERSION = standard_version + + if data_dependent_shapes is not None: + DATA_DEPENDENT_SHAPES = data_dependent_shapes + + if enabled_extensions is not None: + for extension in enabled_extensions: + if extension not in all_extensions: + raise ValueError(f"Unsupported extension {extension}") + if extension_versions[extension] > STANDARD_VERSION: + raise ValueError( + f"Extension {extension} requires standard version " + f"{extension_versions[extension]} or later" + ) + ENABLED_EXTENSIONS = enabled_extensions + +# We have to do this separately or it won't get added as the docstring +set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( + supported_versions=supported_versions, + default_version=STANDARD_VERSION, + default_extensions=ENABLED_EXTENSIONS, +) + +def get_array_api_strict_flags(): + """ + Get the current array-api-strict flags. + + .. note:: + + This function is **not** part of the array API standard. It only exists + in array-api-strict. + + Returns + ------- + dict + A dictionary containing the current array-api-strict flags. + + Examples + -------- + + >>> from array_api_strict import get_array_api_strict_flags + >>> flags = get_array_api_strict_flags() + >>> flags + {'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ['linalg', 'fft']} + + See Also + -------- + + set_array_api_strict_flags + reset_array_api_strict_flags + ArrayApiStrictFlags: A context manager to temporarily set the flags. + + """ + return { + "standard_version": STANDARD_VERSION, + "data_dependent_shapes": DATA_DEPENDENT_SHAPES, + "enabled_extensions": ENABLED_EXTENSIONS, + } + + +def reset_array_api_strict_flags(): + """ + Reset the array-api-strict flags to their default values. + + .. note:: + + This function is **not** part of the array API standard. It only exists + in array-api-strict. + + Examples + -------- + + >>> from array_api_strict import reset_array_api_strict_flags + >>> reset_array_api_strict_flags() + + See Also + -------- + + set_array_api_strict_flags + get_array_api_strict_flags + ArrayApiStrictFlags: A context manager to temporarily set the flags. + + """ + global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS + STANDARD_VERSION = "2022.12" + DATA_DEPENDENT_SHAPES = True + ENABLED_EXTENSIONS = ["linalg", "fft"] + + +class ArrayApiStrictFlags: + """ + A context manager to temporarily set the array-api-strict flags. + + .. note:: + + This class is **not** part of the array API standard. It only exists + in array-api-strict. + + See :func:`~.array_api_strict.set_array_api_strict_flags` for a + description of the available flags. + + See Also + -------- + + set_array_api_strict_flags + get_array_api_strict_flags + reset_array_api_strict_flags + + """ + def __init__(self, *, standard_version=None, data_dependent_shapes=None, + enabled_extensions=None): + self.kwargs = { + "standard_version": standard_version, + "data_dependent_shapes": data_dependent_shapes, + "enabled_extensions": enabled_extensions, + } + self.old_flags = get_array_api_strict_flags() + + def __enter__(self): + set_array_api_strict_flags(**self.kwargs) + + def __exit__(self, exc_type, exc_value, traceback): + set_array_api_strict_flags(**self.old_flags) + +# Set the flags from the environment variables +if "ARRAY_API_STRICT_STANDARD_VERSION" in os.environ: + set_array_api_strict_flags( + standard_version=os.environ["ARRAY_API_STRICT_STANDARD_VERSION"] + ) + +if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ: + set_array_api_strict_flags( + data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true" + ) + +if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ: + set_array_api_strict_flags( + enabled_extensions=os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") + ) From d8c3745372e2d87b0be8bc87c5709246fe8d455a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 9 Apr 2024 17:27:57 -0600 Subject: [PATCH 011/252] Disable extensions when setting the standard version --- array_api_strict/_flags.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 0d76903..a26258e 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -129,7 +129,9 @@ def set_array_api_strict_flags( f"Extension {extension} requires standard version " f"{extension_versions[extension]} or later" ) - ENABLED_EXTENSIONS = enabled_extensions + ENABLED_EXTENSIONS = tuple(enabled_extensions) + else: + ENABLED_EXTENSIONS = tuple([ext for ext in all_extensions if extension_versions[ext] <= STANDARD_VERSION]) # We have to do this separately or it won't get added as the docstring set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( From f34576c63b2d2c7f3c07cf03b2d48511010118fe Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 9 Apr 2024 17:29:08 -0600 Subject: [PATCH 012/252] Some small code cleanups to the flags file --- array_api_strict/_flags.py | 63 ++++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index a26258e..1d50ba3 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -11,31 +11,34 @@ library will only support one particular configuration of these flags. """ +import functools import os -supported_versions = [ +supported_versions = ( "2021.12", "2022.12", -] +) -STANDARD_VERSION = "2022.12" +STANDARD_VERSION = default_version = "2022.12" DATA_DEPENDENT_SHAPES = True -all_extensions = [ +all_extensions = ( "linalg", "fft", -] +) extension_versions = { "linalg": "2021.12", "fft": "2022.12", } -ENABLED_EXTENSIONS = [ +ENABLED_EXTENSIONS = default_extensions = ( "linalg", "fft", -] +) + +# Public functions def set_array_api_strict_flags( *, @@ -136,8 +139,8 @@ def set_array_api_strict_flags( # We have to do this separately or it won't get added as the docstring set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( supported_versions=supported_versions, - default_version=STANDARD_VERSION, - default_extensions=ENABLED_EXTENSIONS, + default_version=default_version, + default_extensions=default_extensions, ) def get_array_api_strict_flags(): @@ -160,7 +163,7 @@ def get_array_api_strict_flags(): >>> from array_api_strict import get_array_api_strict_flags >>> flags = get_array_api_strict_flags() >>> flags - {'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ['linalg', 'fft']} + {'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')} See Also -------- @@ -181,6 +184,8 @@ def reset_array_api_strict_flags(): """ Reset the array-api-strict flags to their default values. + This will also reset any flags that were set by environment variables. + .. note:: This function is **not** part of the array API standard. It only exists @@ -201,9 +206,9 @@ def reset_array_api_strict_flags(): """ global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS - STANDARD_VERSION = "2022.12" + STANDARD_VERSION = default_version DATA_DEPENDENT_SHAPES = True - ENABLED_EXTENSIONS = ["linalg", "fft"] + ENABLED_EXTENSIONS = default_extensions class ArrayApiStrictFlags: @@ -241,18 +246,22 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): set_array_api_strict_flags(**self.old_flags) -# Set the flags from the environment variables -if "ARRAY_API_STRICT_STANDARD_VERSION" in os.environ: - set_array_api_strict_flags( - standard_version=os.environ["ARRAY_API_STRICT_STANDARD_VERSION"] - ) - -if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ: - set_array_api_strict_flags( - data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true" - ) - -if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ: - set_array_api_strict_flags( - enabled_extensions=os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") - ) +# Private functions + +def set_flags_from_environment(): + if "ARRAY_API_STRICT_STANDARD_VERSION" in os.environ: + set_array_api_strict_flags( + standard_version=os.environ["ARRAY_API_STRICT_STANDARD_VERSION"] + ) + + if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ: + set_array_api_strict_flags( + data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true" + ) + + if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ: + set_array_api_strict_flags( + enabled_extensions=os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") + ) + +set_flags_from_environment() From 6a20e916827ae4c7c46a00e82c9dd8621c3285b8 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 9 Apr 2024 17:29:27 -0600 Subject: [PATCH 013/252] Add functionality for the data_dependent_shapes flag --- array_api_strict/_array_object.py | 19 ++++++++++++------- array_api_strict/_flags.py | 8 ++++++++ array_api_strict/_searching_functions.py | 2 ++ array_api_strict/_set_functions.py | 6 ++++++ 4 files changed, 28 insertions(+), 7 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 39808f0..e58767f 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -32,6 +32,7 @@ _result_type, _dtype_categories, ) +from ._flags import get_array_api_strict_flags from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex import types @@ -427,13 +428,17 @@ def _validate_index(self, key): "the Array API)" ) elif isinstance(i, Array): - if i.dtype in _boolean_dtypes and len(_key) != 1: - assert isinstance(key, tuple) # sanity check - raise IndexError( - f"Single-axes index {i} is a boolean array and " - f"{len(key)=}, but masking is only specified in the " - "Array API when the array is the sole index." - ) + if i.dtype in _boolean_dtypes: + if len(_key) != 1: + assert isinstance(key, tuple) # sanity check + raise IndexError( + f"Single-axes index {i} is a boolean array and " + f"{len(key)=}, but masking is only specified in the " + "Array API when the array is the sole index." + ) + if not get_array_api_strict_flags()['data_dependent_shapes']: + raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") + elif i.dtype in _integer_dtypes and i.ndim != 0: raise IndexError( f"Single-axes index {i} is a non-zero-dimensional " diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 1d50ba3..33599ea 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -265,3 +265,11 @@ def set_flags_from_environment(): ) set_flags_from_environment() + +def requires_data_dependent_shapes(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not DATA_DEPENDENT_SHAPES: + raise RuntimeError(f"The function {func.__name__} requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") + return func(*args, **kwargs) + return wrapper diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index f4b2f56..9781531 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -2,6 +2,7 @@ from ._array_object import Array from ._dtypes import _result_type, _real_numeric_dtypes +from ._flags import requires_data_dependent_shapes from typing import Optional, Tuple @@ -30,6 +31,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) +@requires_data_dependent_shapes def nonzero(x: Array, /) -> Tuple[Array, ...]: """ Array API compatible wrapper for :py:func:`np.nonzero `. diff --git a/array_api_strict/_set_functions.py b/array_api_strict/_set_functions.py index 0b4132c..e6ca939 100644 --- a/array_api_strict/_set_functions.py +++ b/array_api_strict/_set_functions.py @@ -2,6 +2,8 @@ from ._array_object import Array +from ._flags import requires_data_dependent_shapes + from typing import NamedTuple import numpy as np @@ -35,6 +37,7 @@ class UniqueInverseResult(NamedTuple): inverse_indices: Array +@requires_data_dependent_shapes def unique_all(x: Array, /) -> UniqueAllResult: """ Array API compatible wrapper for :py:func:`np.unique `. @@ -59,6 +62,7 @@ def unique_all(x: Array, /) -> UniqueAllResult: ) +@requires_data_dependent_shapes def unique_counts(x: Array, /) -> UniqueCountsResult: res = np.unique( x._array, @@ -71,6 +75,7 @@ def unique_counts(x: Array, /) -> UniqueCountsResult: return UniqueCountsResult(*[Array._new(i) for i in res]) +@requires_data_dependent_shapes def unique_inverse(x: Array, /) -> UniqueInverseResult: """ Array API compatible wrapper for :py:func:`np.unique `. @@ -90,6 +95,7 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult: return UniqueInverseResult(Array._new(values), Array._new(inverse_indices)) +@requires_data_dependent_shapes def unique_values(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.unique `. From 4705b9fb25b075bb2c7695acc6dba88fc6b79fa9 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 9 Apr 2024 17:29:38 -0600 Subject: [PATCH 014/252] Add tests for flags --- array_api_strict/tests/test_flags.py | 78 ++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 array_api_strict/tests/test_flags.py diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py new file mode 100644 index 0000000..ede4b96 --- /dev/null +++ b/array_api_strict/tests/test_flags.py @@ -0,0 +1,78 @@ +from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags, + reset_array_api_strict_flags) + +from .. import (asarray, unique_all, unique_counts, unique_inverse, + unique_values, nonzero) + +import pytest + +@pytest.fixture(autouse=True) +def reset_flags(): + reset_array_api_strict_flags() + yield + reset_array_api_strict_flags() + +def test_flags(): + # Test defaults + flags = get_array_api_strict_flags() + assert flags == { + 'standard_version': '2022.12', + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } + + # Test setting flags + set_array_api_strict_flags(data_dependent_shapes=False) + flags = get_array_api_strict_flags() + assert flags == { + 'standard_version': '2022.12', + 'data_dependent_shapes': False, + 'enabled_extensions': ('linalg', 'fft'), + } + set_array_api_strict_flags(enabled_extensions=('fft',)) + flags = get_array_api_strict_flags() + assert flags == { + 'standard_version': '2022.12', + 'data_dependent_shapes': False, + 'enabled_extensions': ('fft',), + } + # Make sure setting the version to 2021.12 disables fft + set_array_api_strict_flags(standard_version='2021.12') + flags = get_array_api_strict_flags() + assert flags == { + 'standard_version': '2021.12', + 'data_dependent_shapes': False, + 'enabled_extensions': ('linalg',), + } + + # Test setting flags with invalid values + pytest.raises(ValueError, lambda: + set_array_api_strict_flags(standard_version='2020.12')) + pytest.raises(ValueError, lambda: set_array_api_strict_flags( + enabled_extensions=('linalg', 'fft', 'invalid'))) + pytest.raises(ValueError, lambda: set_array_api_strict_flags( + standard_version='2021.12', + enabled_extensions=('linalg', 'fft'))) + + +def test_data_dependent_shapes(): + a = asarray([0, 0, 1, 2, 2]) + mask = asarray([True, False, True, False, True]) + + # Should not error + unique_all(a) + unique_counts(a) + unique_inverse(a) + unique_values(a) + nonzero(a) + a[mask] + # TODO: add repeat when it is implemented + + set_array_api_strict_flags(data_dependent_shapes=False) + + pytest.raises(RuntimeError, lambda: unique_all(a)) + pytest.raises(RuntimeError, lambda: unique_counts(a)) + pytest.raises(RuntimeError, lambda: unique_inverse(a)) + pytest.raises(RuntimeError, lambda: unique_values(a)) + pytest.raises(RuntimeError, lambda: nonzero(a)) + pytest.raises(RuntimeError, lambda: a[mask]) From 689a776aa6a581c321b6793985c9c135880d149d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 12 Apr 2024 16:44:23 -0600 Subject: [PATCH 015/252] Respect the extension flag in linalg and fft This behavior still needs to be tested. This required moving the linalg functions that are also in the main namespace so that they can still work there even when the linalg extension is disabled. The way I've decided to implement this is that the functions will not raise an exception until they are called. It would probably be more convenient for users if they raised an attribute error, or if the extension namespace itself did, like it would in a real library without the given extension. But the implementation for this would be a lot more complicated and didn't really feel worth it to me. --- array_api_strict/__init__.py | 2 +- array_api_strict/_flags.py | 14 +++ array_api_strict/_linear_algebra_functions.py | 68 ++++++++++++ array_api_strict/fft.py | 15 +++ array_api_strict/linalg.py | 104 +++++++++--------- 5 files changed, 149 insertions(+), 54 deletions(-) create mode 100644 array_api_strict/_linear_algebra_functions.py diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 1fab8a2..31f0992 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -244,7 +244,7 @@ __all__ += ["linalg"] -from .linalg import matmul, tensordot, matrix_transpose, vecdot +from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot __all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"] diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 33599ea..bbe2c59 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -273,3 +273,17 @@ def wrapper(*args, **kwargs): raise RuntimeError(f"The function {func.__name__} requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") return func(*args, **kwargs) return wrapper + +def requires_extension(extension): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if extension not in ENABLED_EXTENSIONS: + if extension == 'linalg' \ + and func.__name__ in ['matmul', 'tensordot', + 'matrix_transpose', 'vecdot']: + raise RuntimeError(f"The linalg extension has been disabled for array-api-strict. However, {func.__name__} is also present in the main array_api_strict namespace and may be used from there.") + raise RuntimeError(f"The function {func.__name__} requires the {extension} extension, but it has been disabled for array-api-strict") + return func(*args, **kwargs) + return wrapper + return decorator diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py new file mode 100644 index 0000000..1ff08d4 --- /dev/null +++ b/array_api_strict/_linear_algebra_functions.py @@ -0,0 +1,68 @@ +""" +These functions are all also defined in the linalg extension, but we include +them here with wrappers in linalg so that the wrappers can be disabled if the +linalg extension is disabled in the flags. + +""" + +from __future__ import annotations + +from ._dtypes import _numeric_dtypes + +from ._array_object import Array + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ._typing import Sequence, Tuple, Union + +import numpy.linalg +import numpy as np + +# Note: matmul is the numpy top-level namespace but not in np.linalg +def matmul(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.matmul `. + + See its docstring for more information. + """ + # Note: the restriction to numeric dtypes only is different from + # np.matmul. + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in matmul') + + return Array._new(np.matmul(x1._array, x2._array)) + +# Note: tensordot is the numpy top-level namespace but not in np.linalg + +# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. +def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: + # Note: the restriction to numeric dtypes only is different from + # np.tensordot. + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in tensordot') + + return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) + +# Note: this function is new in the array API spec. Unlike transpose, it only +# transposes the last two axes. +def matrix_transpose(x: Array, /) -> Array: + if x.ndim < 2: + raise ValueError("x must be at least 2-dimensional for matrix_transpose") + return Array._new(np.swapaxes(x._array, -1, -2)) + +# Note: vecdot is not in NumPy +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in vecdot') + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + x1_, x2_ = np.broadcast_arrays(x1._array, x2._array) + x1_ = np.moveaxis(x1_, axis, -1) + x2_ = np.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return Array._new(res[..., 0, 0]) diff --git a/array_api_strict/fft.py b/array_api_strict/fft.py index b50e9e3..7f427e5 100644 --- a/array_api_strict/fft.py +++ b/array_api_strict/fft.py @@ -15,9 +15,11 @@ ) from ._array_object import Array, CPU_DEVICE from ._data_type_functions import astype +from ._flags import requires_extension import numpy as np +@requires_extension('fft') def fft( x: Array, /, @@ -40,6 +42,7 @@ def fft( return astype(res, complex64) return res +@requires_extension('fft') def ifft( x: Array, /, @@ -62,6 +65,7 @@ def ifft( return astype(res, complex64) return res +@requires_extension('fft') def fftn( x: Array, /, @@ -84,6 +88,7 @@ def fftn( return astype(res, complex64) return res +@requires_extension('fft') def ifftn( x: Array, /, @@ -106,6 +111,7 @@ def ifftn( return astype(res, complex64) return res +@requires_extension('fft') def rfft( x: Array, /, @@ -128,6 +134,7 @@ def rfft( return astype(res, complex64) return res +@requires_extension('fft') def irfft( x: Array, /, @@ -150,6 +157,7 @@ def irfft( return astype(res, float32) return res +@requires_extension('fft') def rfftn( x: Array, /, @@ -172,6 +180,7 @@ def rfftn( return astype(res, complex64) return res +@requires_extension('fft') def irfftn( x: Array, /, @@ -194,6 +203,7 @@ def irfftn( return astype(res, float32) return res +@requires_extension('fft') def hfft( x: Array, /, @@ -216,6 +226,7 @@ def hfft( return astype(res, float32) return res +@requires_extension('fft') def ihfft( x: Array, /, @@ -238,6 +249,7 @@ def ihfft( return astype(res, complex64) return res +@requires_extension('fft') def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.fftfreq `. @@ -248,6 +260,7 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Ar raise ValueError(f"Unsupported device {device!r}") return Array._new(np.fft.fftfreq(n, d=d)) +@requires_extension('fft') def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.rfftfreq `. @@ -258,6 +271,7 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> A raise ValueError(f"Unsupported device {device!r}") return Array._new(np.fft.rfftfreq(n, d=d)) +@requires_extension('fft') def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.fftshift `. @@ -268,6 +282,7 @@ def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: raise TypeError("Only floating-point dtypes are allowed in fftshift") return Array._new(np.fft.fftshift(x._array, axes=axes)) +@requires_extension('fft') def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.ifftshift `. diff --git a/array_api_strict/linalg.py b/array_api_strict/linalg.py index 78e9ec4..e1998fa 100644 --- a/array_api_strict/linalg.py +++ b/array_api_strict/linalg.py @@ -12,6 +12,7 @@ from ._manipulation_functions import reshape from ._elementwise_functions import conj from ._array_object import Array +from ._flags import requires_extension try: from numpy._core.numeric import normalize_axis_tuple @@ -46,6 +47,7 @@ class SVDResult(NamedTuple): # Note: the inclusion of the upper keyword is different from # np.linalg.cholesky, which does not have it. +@requires_extension('linalg') def cholesky(x: Array, /, *, upper: bool = False) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.cholesky `. @@ -65,6 +67,7 @@ def cholesky(x: Array, /, *, upper: bool = False) -> Array: return Array._new(L) # Note: cross is the numpy top-level namespace, not np.linalg +@requires_extension('linalg') def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: """ Array API compatible wrapper for :py:func:`np.cross `. @@ -80,6 +83,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: raise ValueError('cross() dimension must equal 3') return Array._new(np.cross(x1._array, x2._array, axis=axis)) +@requires_extension('linalg') def det(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.det `. @@ -93,6 +97,7 @@ def det(x: Array, /) -> Array: return Array._new(np.linalg.det(x._array)) # Note: diagonal is the numpy top-level namespace, not np.linalg +@requires_extension('linalg') def diagonal(x: Array, /, *, offset: int = 0) -> Array: """ Array API compatible wrapper for :py:func:`np.diagonal `. @@ -103,7 +108,7 @@ def diagonal(x: Array, /, *, offset: int = 0) -> Array: # operates on the first two axes by default return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1)) - +@requires_extension('linalg') def eigh(x: Array, /) -> EighResult: """ Array API compatible wrapper for :py:func:`np.linalg.eigh `. @@ -120,6 +125,7 @@ def eigh(x: Array, /) -> EighResult: return EighResult(*map(Array._new, np.linalg.eigh(x._array))) +@requires_extension('linalg') def eigvalsh(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.eigvalsh `. @@ -133,6 +139,7 @@ def eigvalsh(x: Array, /) -> Array: return Array._new(np.linalg.eigvalsh(x._array)) +@requires_extension('linalg') def inv(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.inv `. @@ -146,28 +153,13 @@ def inv(x: Array, /) -> Array: return Array._new(np.linalg.inv(x._array)) - -# Note: matmul is the numpy top-level namespace but not in np.linalg -def matmul(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.matmul `. - - See its docstring for more information. - """ - # Note: the restriction to numeric dtypes only is different from - # np.matmul. - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in matmul') - - return Array._new(np.matmul(x1._array, x2._array)) - - # Note: the name here is different from norm(). The array API norm is split # into matrix_norm and vector_norm(). # The type for ord should be Optional[Union[int, float, Literal[np.inf, # -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point # literals. +@requires_extension('linalg') def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm `. @@ -182,6 +174,7 @@ def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord)) +@requires_extension('linalg') def matrix_power(x: Array, n: int, /) -> Array: """ Array API compatible wrapper for :py:func:`np.matrix_power `. @@ -197,6 +190,7 @@ def matrix_power(x: Array, n: int, /) -> Array: return Array._new(np.linalg.matrix_power(x._array, n)) # Note: the keyword argument name rtol is different from np.linalg.matrix_rank +@requires_extension('linalg') def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.matrix_rank `. @@ -219,14 +213,8 @@ def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> A return Array._new(np.count_nonzero(S > tol, axis=-1)) -# Note: this function is new in the array API spec. Unlike transpose, it only -# transposes the last two axes. -def matrix_transpose(x: Array, /) -> Array: - if x.ndim < 2: - raise ValueError("x must be at least 2-dimensional for matrix_transpose") - return Array._new(np.swapaxes(x._array, -1, -2)) - # Note: outer is the numpy top-level namespace, not np.linalg +@requires_extension('linalg') def outer(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.outer `. @@ -245,6 +233,7 @@ def outer(x1: Array, x2: Array, /) -> Array: return Array._new(np.outer(x1._array, x2._array)) # Note: the keyword argument name rtol is different from np.linalg.pinv +@requires_extension('linalg') def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.pinv `. @@ -262,6 +251,7 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: rtol = max(x.shape[-2:]) * finfo(x.dtype).eps return Array._new(np.linalg.pinv(x._array, rcond=rtol)) +@requires_extension('linalg') def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: """ Array API compatible wrapper for :py:func:`np.linalg.qr `. @@ -277,6 +267,7 @@ def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRRe # np.linalg.qr, which only returns a tuple. return QRResult(*map(Array._new, np.linalg.qr(x._array, mode=mode))) +@requires_extension('linalg') def slogdet(x: Array, /) -> SlogdetResult: """ Array API compatible wrapper for :py:func:`np.linalg.slogdet `. @@ -335,6 +326,7 @@ def _solve(a, b): return wrap(r.astype(result_t, copy=False)) +@requires_extension('linalg') def solve(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.solve `. @@ -348,6 +340,7 @@ def solve(x1: Array, x2: Array, /) -> Array: return Array._new(_solve(x1._array, x2._array)) +@requires_extension('linalg') def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: """ Array API compatible wrapper for :py:func:`np.linalg.svd `. @@ -365,23 +358,14 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: # Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to # np.linalg.svd(compute_uv=False). +@requires_extension('linalg') def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]: if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in svdvals') return Array._new(np.linalg.svd(x._array, compute_uv=False)) -# Note: tensordot is the numpy top-level namespace but not in np.linalg - -# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. -def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: - # Note: the restriction to numeric dtypes only is different from - # np.tensordot. - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in tensordot') - - return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) - # Note: trace is the numpy top-level namespace, not np.linalg +@requires_extension('linalg') def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.trace `. @@ -404,29 +388,12 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr # operates on the first two axes by default return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype))) -# Note: vecdot is not in NumPy -def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in vecdot') - ndim = max(x1.ndim, x2.ndim) - x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) - x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) - if x1_shape[axis] != x2_shape[axis]: - raise ValueError("x1 and x2 must have the same size along the given axis") - - x1_, x2_ = np.broadcast_arrays(x1._array, x2._array) - x1_ = np.moveaxis(x1_, axis, -1) - x2_ = np.moveaxis(x2_, axis, -1) - - res = x1_[..., None, :] @ x2_[..., None] - return Array._new(res[..., 0, 0]) - - # Note: the name here is different from norm(). The array API norm is split # into matrix_norm and vector_norm(). # The type for ord should be Optional[Union[int, float, Literal[np.inf, # -np.inf]]] but Literal does not support floating-point literals. +@requires_extension('linalg') def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm `. @@ -472,4 +439,35 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No return res +# These functions are also in the main namespace. We define them here as +# wrappers so that they can still be disabled when the linalg extension is +# disabled without disabling the versions in the main namespace. + +# Note: matmul is the numpy top-level namespace but not in np.linalg +@requires_extension('linalg') +def matmul(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.matmul `. + + See its docstring for more information. + """ + from ._linear_algebra_functions import matmul + return matmul(x1, x2) + +# Note: tensordot is the numpy top-level namespace but not in np.linalg +@requires_extension('linalg') +def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: + from ._linear_algebra_functions import tensordot + return tensordot(x1, x2, axes=axes) + +@requires_extension('linalg') +def matrix_transpose(x: Array, /) -> Array: + from ._linear_algebra_functions import matrix_transpose + return matrix_transpose(x) + +@requires_extension('linalg') +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: + from ._linear_algebra_functions import vecdot + return vecdot(x1, x2, axis=axis) + __all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm'] From 78d368d6ba3f3adb2f138b0a0d8cd8294db61f0a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 12 Apr 2024 17:06:17 -0600 Subject: [PATCH 016/252] Add tests for disabling linalg and fft extensions --- array_api_strict/tests/test_flags.py | 91 ++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index ede4b96..d3d957f 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -4,6 +4,8 @@ from .. import (asarray, unique_all, unique_counts, unique_inverse, unique_values, nonzero) +import array_api_strict as xp + import pytest @pytest.fixture(autouse=True) @@ -76,3 +78,92 @@ def test_data_dependent_shapes(): pytest.raises(RuntimeError, lambda: unique_values(a)) pytest.raises(RuntimeError, lambda: nonzero(a)) pytest.raises(RuntimeError, lambda: a[mask]) + +linalg_examples = { + 'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)), + 'cross': lambda: xp.linalg.cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])), + 'det': lambda: xp.linalg.det(xp.eye(3)), + 'diagonal': lambda: xp.linalg.diagonal(xp.eye(3)), + 'eigh': lambda: xp.linalg.eigh(xp.eye(3)), + 'eigvalsh': lambda: xp.linalg.eigvalsh(xp.eye(3)), + 'inv': lambda: xp.linalg.inv(xp.eye(3)), + 'matmul': lambda: xp.linalg.matmul(xp.eye(3), xp.eye(3)), + 'matrix_norm': lambda: xp.linalg.matrix_norm(xp.eye(3)), + 'matrix_power': lambda: xp.linalg.matrix_power(xp.eye(3), 2), + 'matrix_rank': lambda: xp.linalg.matrix_rank(xp.eye(3)), + 'matrix_transpose': lambda: xp.linalg.matrix_transpose(xp.eye(3)), + 'outer': lambda: xp.linalg.outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'pinv': lambda: xp.linalg.pinv(xp.eye(3)), + 'qr': lambda: xp.linalg.qr(xp.eye(3)), + 'slogdet': lambda: xp.linalg.slogdet(xp.eye(3)), + 'solve': lambda: xp.linalg.solve(xp.eye(3), xp.eye(3)), + 'svd': lambda: xp.linalg.svd(xp.eye(3)), + 'svdvals': lambda: xp.linalg.svdvals(xp.eye(3)), + 'tensordot': lambda: xp.linalg.tensordot(xp.eye(3), xp.eye(3)), + 'trace': lambda: xp.linalg.trace(xp.eye(3)), + 'vecdot': lambda: xp.linalg.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'vector_norm': lambda: xp.linalg.vector_norm(xp.asarray([1., 2., 3.])), +} + +assert set(linalg_examples) == set(xp.linalg.__all__) + +linalg_main_namespace_examples = { + 'matmul': lambda: xp.matmul(xp.eye(3), xp.eye(3)), + 'matrix_transpose': lambda: xp.matrix_transpose(xp.eye(3)), + 'tensordot': lambda: xp.tensordot(xp.eye(3), xp.eye(3)), + 'vecdot': lambda: xp.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), +} + +assert set(linalg_main_namespace_examples) == set(xp.__all__) & set(xp.linalg.__all__) + +@pytest.mark.parametrize('func_name', linalg_examples.keys()) +def test_linalg(func_name): + func = linalg_examples[func_name] + if func_name in linalg_main_namespace_examples: + main_namespace_func = linalg_main_namespace_examples[func_name] + else: + main_namespace_func = lambda: None + + # First make sure the example actually works + func() + main_namespace_func() + + set_array_api_strict_flags(enabled_extensions=()) + pytest.raises(RuntimeError, func) + main_namespace_func() + + set_array_api_strict_flags(enabled_extensions=('linalg',)) + func() + main_namespace_func() + +fft_examples = { + 'fft': lambda: xp.fft.fft(xp.asarray([0j, 1j, 0j, 0j])), + 'ifft': lambda: xp.fft.ifft(xp.asarray([0j, 1j, 0j, 0j])), + 'fftn': lambda: xp.fft.fftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'ifftn': lambda: xp.fft.ifftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'rfft': lambda: xp.fft.rfft(xp.asarray([0., 1., 0., 0.])), + 'irfft': lambda: xp.fft.irfft(xp.asarray([0j, 1j, 0j, 0j])), + 'rfftn': lambda: xp.fft.rfftn(xp.asarray([[0., 1.], [0., 0.]])), + 'irfftn': lambda: xp.fft.irfftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'hfft': lambda: xp.fft.hfft(xp.asarray([0j, 1j, 0j, 0j])), + 'ihfft': lambda: xp.fft.ihfft(xp.asarray([0., 1., 0., 0.])), + 'fftfreq': lambda: xp.fft.fftfreq(4), + 'rfftfreq': lambda: xp.fft.rfftfreq(4), + 'fftshift': lambda: xp.fft.fftshift(xp.asarray([0j, 1j, 0j, 0j])), + 'ifftshift': lambda: xp.fft.ifftshift(xp.asarray([0j, 1j, 0j, 0j])), +} + +assert set(fft_examples) == set(xp.fft.__all__) + +@pytest.mark.parametrize('func_name', fft_examples.keys()) +def test_fft(func_name): + func = fft_examples[func_name] + + # First make sure the example actually works + func() + + set_array_api_strict_flags(enabled_extensions=()) + pytest.raises(RuntimeError, func) + + set_array_api_strict_flags(enabled_extensions=('fft',)) + func() From c32a452864d10afbeebac31cf9d30858289a70fa Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 12 Apr 2024 23:37:06 -0600 Subject: [PATCH 017/252] Add docstring to __array_namespace__ --- array_api_strict/_array_object.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index e58767f..9150381 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -487,6 +487,20 @@ def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: def __array_namespace__( self: Array, /, *, api_version: Optional[str] = None ) -> types.ModuleType: + """ + Return the array_api_strict namespace corresponding to api_version. + + The default API version is '2022.12'. Note that '2021.12' is supported, + but currently identical to '2022.12'. + + For array_api_strict, calling this function with api_version will set + the API version for the array_api_strict module globally. This can + also be achieved with the + {func}`array_api_strict.set_array_api_strict_flags` function. If you + want some way to only set the version locally, use the + {class}`array_api_strict.ArrayApiStrictFlags` context manager. + + """ if api_version is not None and api_version not in ["2021.12", "2022.12"]: raise ValueError(f"Unrecognized array API version: {api_version!r}") if api_version == "2021.12": From 632900f92082f7701cb2cdd79b0f99dcf336649c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 16 Apr 2024 12:52:57 -0600 Subject: [PATCH 018/252] Remove duplicate sentence --- array_api_strict/_flags.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index bbe2c59..13ba89f 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -62,12 +62,11 @@ def set_array_api_strict_flags( ``{default_version!r}``. - `data_dependent_shapes`: Whether data-dependent shapes are enabled in - array-api-strict. This flag is enabled by default. Array libraries that - use computation graphs may not be able to support functions whose output - shapes depend on the input data. + array-api-strict. - This flag is enabled by default. Array libraries that use computation graphs may not be able to support - functions whose output shapes depend on the input data. + This flag is enabled by default. Array libraries that use computation + graphs may not be able to support functions whose output shapes depend + on the input data. The functions that make use of data-dependent shapes, and are therefore disabled by setting this flag to False are From befd28c198ad010e2036221429c44ef3265085e8 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 16 Apr 2024 12:55:52 -0600 Subject: [PATCH 019/252] Set the api version flag in __array_namespace__ --- array_api_strict/_array_object.py | 5 ++--- array_api_strict/_flags.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 9150381..a2d68a5 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -32,7 +32,7 @@ _result_type, _dtype_categories, ) -from ._flags import get_array_api_strict_flags +from ._flags import get_array_api_strict_flags, set_array_api_strict_flags from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex import types @@ -501,8 +501,7 @@ def __array_namespace__( {class}`array_api_strict.ArrayApiStrictFlags` context manager. """ - if api_version is not None and api_version not in ["2021.12", "2022.12"]: - raise ValueError(f"Unrecognized array API version: {api_version!r}") + set_array_api_strict_flags(standard_version=api_version) if api_version == "2021.12": warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") import array_api_strict diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 13ba89f..eb76289 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -116,7 +116,7 @@ def set_array_api_strict_flags( if standard_version is not None: if standard_version not in supported_versions: - raise ValueError(f"Unsupported standard version {standard_version}") + raise ValueError(f"Unsupported standard version {standard_version!r}") STANDARD_VERSION = standard_version if data_dependent_shapes is not None: From 319799ee31e0c3778f27c79ca299b0bddab11317 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 16 Apr 2024 12:57:32 -0600 Subject: [PATCH 020/252] Rename the "standard_version" flag to "api_version" This matches the name used in __array_namespace__ --- array_api_strict/_array_object.py | 2 +- array_api_strict/_flags.py | 40 ++++++++++++++-------------- array_api_strict/tests/test_flags.py | 14 +++++----- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index a2d68a5..28d0eb9 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -501,7 +501,7 @@ def __array_namespace__( {class}`array_api_strict.ArrayApiStrictFlags` context manager. """ - set_array_api_strict_flags(standard_version=api_version) + set_array_api_strict_flags(api_version=api_version) if api_version == "2021.12": warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") import array_api_strict diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index eb76289..d02476d 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -19,7 +19,7 @@ "2022.12", ) -STANDARD_VERSION = default_version = "2022.12" +API_VERSION = default_version = "2022.12" DATA_DEPENDENT_SHAPES = True @@ -42,7 +42,7 @@ def set_array_api_strict_flags( *, - standard_version=None, + api_version=None, data_dependent_shapes=None, enabled_extensions=None, ): @@ -57,7 +57,7 @@ def set_array_api_strict_flags( This function is **not** part of the array API standard. It only exists in array-api-strict. - - `standard_version`: The version of the standard to use. Supported + - `api_version`: The version of the standard to use. Supported versions are: ``{supported_versions}``. The default version number is ``{default_version!r}``. @@ -88,7 +88,7 @@ def set_array_api_strict_flags( The default values of the flags can also be changed by setting environment variables: - - ``ARRAY_API_STRICT_STANDARD_VERSION``: A string representing the version number. + - ``ARRAY_API_STRICT_API_VERSION``: A string representing the version number. - ``ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES``: "True" or "False". - ``ARRAY_API_STRICT_ENABLED_EXTENSIONS``: A comma separated list of extensions to enable. @@ -98,7 +98,7 @@ def set_array_api_strict_flags( >>> from array_api_strict import set_array_api_strict_flags >>> # Set the standard version to 2021.12 - >>> set_array_api_strict_flags(standard_version="2021.12") + >>> set_array_api_strict_flags(api_version="2021.12") >>> # Disable data-dependent shapes >>> set_array_api_strict_flags(data_dependent_shapes=False) >>> # Enable only the linalg extension (disable the fft extension) @@ -112,12 +112,12 @@ def set_array_api_strict_flags( ArrayApiStrictFlags: A context manager to temporarily set the flags. """ - global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS + global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS - if standard_version is not None: - if standard_version not in supported_versions: - raise ValueError(f"Unsupported standard version {standard_version!r}") - STANDARD_VERSION = standard_version + if api_version is not None: + if api_version not in supported_versions: + raise ValueError(f"Unsupported standard version {api_version!r}") + API_VERSION = api_version if data_dependent_shapes is not None: DATA_DEPENDENT_SHAPES = data_dependent_shapes @@ -126,14 +126,14 @@ def set_array_api_strict_flags( for extension in enabled_extensions: if extension not in all_extensions: raise ValueError(f"Unsupported extension {extension}") - if extension_versions[extension] > STANDARD_VERSION: + if extension_versions[extension] > API_VERSION: raise ValueError( f"Extension {extension} requires standard version " f"{extension_versions[extension]} or later" ) ENABLED_EXTENSIONS = tuple(enabled_extensions) else: - ENABLED_EXTENSIONS = tuple([ext for ext in all_extensions if extension_versions[ext] <= STANDARD_VERSION]) + ENABLED_EXTENSIONS = tuple([ext for ext in all_extensions if extension_versions[ext] <= API_VERSION]) # We have to do this separately or it won't get added as the docstring set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( @@ -162,7 +162,7 @@ def get_array_api_strict_flags(): >>> from array_api_strict import get_array_api_strict_flags >>> flags = get_array_api_strict_flags() >>> flags - {'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')} + {'api_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')} See Also -------- @@ -173,7 +173,7 @@ def get_array_api_strict_flags(): """ return { - "standard_version": STANDARD_VERSION, + "api_version": API_VERSION, "data_dependent_shapes": DATA_DEPENDENT_SHAPES, "enabled_extensions": ENABLED_EXTENSIONS, } @@ -204,8 +204,8 @@ def reset_array_api_strict_flags(): ArrayApiStrictFlags: A context manager to temporarily set the flags. """ - global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS - STANDARD_VERSION = default_version + global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS + API_VERSION = default_version DATA_DEPENDENT_SHAPES = True ENABLED_EXTENSIONS = default_extensions @@ -230,10 +230,10 @@ class ArrayApiStrictFlags: reset_array_api_strict_flags """ - def __init__(self, *, standard_version=None, data_dependent_shapes=None, + def __init__(self, *, api_version=None, data_dependent_shapes=None, enabled_extensions=None): self.kwargs = { - "standard_version": standard_version, + "api_version": api_version, "data_dependent_shapes": data_dependent_shapes, "enabled_extensions": enabled_extensions, } @@ -248,9 +248,9 @@ def __exit__(self, exc_type, exc_value, traceback): # Private functions def set_flags_from_environment(): - if "ARRAY_API_STRICT_STANDARD_VERSION" in os.environ: + if "ARRAY_API_STRICT_API_VERSION" in os.environ: set_array_api_strict_flags( - standard_version=os.environ["ARRAY_API_STRICT_STANDARD_VERSION"] + api_version=os.environ["ARRAY_API_STRICT_API_VERSION"] ) if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ: diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index d3d957f..99037a5 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -18,7 +18,7 @@ def test_flags(): # Test defaults flags = get_array_api_strict_flags() assert flags == { - 'standard_version': '2022.12', + 'api_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft'), } @@ -27,33 +27,33 @@ def test_flags(): set_array_api_strict_flags(data_dependent_shapes=False) flags = get_array_api_strict_flags() assert flags == { - 'standard_version': '2022.12', + 'api_version': '2022.12', 'data_dependent_shapes': False, 'enabled_extensions': ('linalg', 'fft'), } set_array_api_strict_flags(enabled_extensions=('fft',)) flags = get_array_api_strict_flags() assert flags == { - 'standard_version': '2022.12', + 'api_version': '2022.12', 'data_dependent_shapes': False, 'enabled_extensions': ('fft',), } # Make sure setting the version to 2021.12 disables fft - set_array_api_strict_flags(standard_version='2021.12') + set_array_api_strict_flags(api_version='2021.12') flags = get_array_api_strict_flags() assert flags == { - 'standard_version': '2021.12', + 'api_version': '2021.12', 'data_dependent_shapes': False, 'enabled_extensions': ('linalg',), } # Test setting flags with invalid values pytest.raises(ValueError, lambda: - set_array_api_strict_flags(standard_version='2020.12')) + set_array_api_strict_flags(api_version='2020.12')) pytest.raises(ValueError, lambda: set_array_api_strict_flags( enabled_extensions=('linalg', 'fft', 'invalid'))) pytest.raises(ValueError, lambda: set_array_api_strict_flags( - standard_version='2021.12', + api_version='2021.12', enabled_extensions=('linalg', 'fft'))) From b4bfb8d14e6aaac1efb5e2cf0efb02cf10462614 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 16 Apr 2024 13:13:25 -0600 Subject: [PATCH 021/252] Move the reset_flags fixture to be global to all the tests --- array_api_strict/tests/conftest.py | 9 +++++++++ array_api_strict/tests/test_flags.py | 6 ------ 2 files changed, 9 insertions(+), 6 deletions(-) create mode 100644 array_api_strict/tests/conftest.py diff --git a/array_api_strict/tests/conftest.py b/array_api_strict/tests/conftest.py new file mode 100644 index 0000000..5000d5d --- /dev/null +++ b/array_api_strict/tests/conftest.py @@ -0,0 +1,9 @@ +from .._flags import reset_array_api_strict_flags + +import pytest + +@pytest.fixture(autouse=True) +def reset_flags(): + reset_array_api_strict_flags() + yield + reset_array_api_strict_flags() diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 99037a5..dff216a 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -8,12 +8,6 @@ import pytest -@pytest.fixture(autouse=True) -def reset_flags(): - reset_array_api_strict_flags() - yield - reset_array_api_strict_flags() - def test_flags(): # Test defaults flags = get_array_api_strict_flags() From 0d758ebac45c6a3e847d18e6cce31491fbd127db Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 16 Apr 2024 13:13:50 -0600 Subject: [PATCH 022/252] Test reset_array_api_strict_flags() --- array_api_strict/tests/test_flags.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index dff216a..5e5e171 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -50,6 +50,18 @@ def test_flags(): api_version='2021.12', enabled_extensions=('linalg', 'fft'))) + # Test resetting flags + set_array_api_strict_flags( + api_version='2021.12', + data_dependent_shapes=False, + enabled_extensions=()) + reset_array_api_strict_flags() + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2022.12', + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } def test_data_dependent_shapes(): a = asarray([0, 0, 1, 2, 2]) From 0ba1267d89f5f5cc37fd673606c8902aae526860 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 16 Apr 2024 13:14:09 -0600 Subject: [PATCH 023/252] Set __array_api_version__ with the api_version flag --- array_api_strict/__init__.py | 6 +++++- array_api_strict/_array_object.py | 2 +- array_api_strict/_flags.py | 5 ++++- array_api_strict/tests/test_array_object.py | 8 ++++++++ array_api_strict/tests/test_flags.py | 8 ++++++++ 5 files changed, 26 insertions(+), 3 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 31f0992..b323a65 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -16,7 +16,11 @@ """ -__array_api_version__ = "2022.12" +# Warning: __array_api_version__ could change globally with +# set_array_api_strict_flags(). This should always be accessed as an +# attribute, like xp.__array_api_version__, or using +# array_api_strict.get_array_api_strict_flags()['api_version']. +from ._flags import API_VERSION as __array_api_version__ __all__ = ["__array_api_version__"] diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 28d0eb9..ff9b8f8 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -497,7 +497,7 @@ def __array_namespace__( the API version for the array_api_strict module globally. This can also be achieved with the {func}`array_api_strict.set_array_api_strict_flags` function. If you - want some way to only set the version locally, use the + want to only set the version locally, use the {class}`array_api_strict.ArrayApiStrictFlags` context manager. """ diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index d02476d..80de965 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -14,6 +14,8 @@ import functools import os +import array_api_strict + supported_versions = ( "2021.12", "2022.12", @@ -37,7 +39,6 @@ "linalg", "fft", ) - # Public functions def set_array_api_strict_flags( @@ -118,6 +119,7 @@ def set_array_api_strict_flags( if api_version not in supported_versions: raise ValueError(f"Unsupported standard version {api_version!r}") API_VERSION = api_version + array_api_strict.__array_api_version__ = API_VERSION if data_dependent_shapes is not None: DATA_DEPENDENT_SHAPES = data_dependent_shapes @@ -206,6 +208,7 @@ def reset_array_api_strict_flags(): """ global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS API_VERSION = default_version + array_api_strict.__array_api_version__ = API_VERSION DATA_DEPENDENT_SHAPES = True ENABLED_EXTENSIONS = default_extensions diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index a0b6132..bae0553 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -402,9 +402,17 @@ def test_array_keys_use_private_array(): def test_array_namespace(): a = ones((3, 3)) assert a.__array_namespace__() == array_api_strict + assert array_api_strict.__array_api_version__ == "2022.12" + assert a.__array_namespace__(api_version=None) is array_api_strict + assert array_api_strict.__array_api_version__ == "2022.12" + assert a.__array_namespace__(api_version="2022.12") is array_api_strict + assert array_api_strict.__array_api_version__ == "2022.12" + with pytest.warns(UserWarning): assert a.__array_namespace__(api_version="2021.12") is array_api_strict + assert array_api_strict.__array_api_version__ == "2021.12" + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2023.12")) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 5e5e171..f9d8ad6 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -63,6 +63,14 @@ def test_flags(): 'enabled_extensions': ('linalg', 'fft'), } +def test_api_version(): + # Test defaults + assert xp.__array_api_version__ == '2022.12' + + # Test setting the version + set_array_api_strict_flags(api_version='2021.12') + assert xp.__array_api_version__ == '2021.12' + def test_data_dependent_shapes(): a = asarray([0, 0, 1, 2, 2]) mask = asarray([True, False, True, False, True]) From 30baeb7daa5197d0babb84cbed15cb43402cb34b Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 18 Apr 2024 16:16:44 -0600 Subject: [PATCH 024/252] Move warning about 2021.12 to set_array_api_strict_flags() --- array_api_strict/_array_object.py | 3 --- array_api_strict/_flags.py | 6 ++++++ array_api_strict/tests/test_flags.py | 17 +++++++++++------ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index ff9b8f8..2b9155a 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -17,7 +17,6 @@ import operator from enum import IntEnum -import warnings from ._creation_functions import asarray from ._dtypes import ( @@ -502,8 +501,6 @@ def __array_namespace__( """ set_array_api_strict_flags(api_version=api_version) - if api_version == "2021.12": - warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") import array_api_strict return array_api_strict diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 80de965..2114620 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -13,6 +13,7 @@ import functools import os +import warnings import array_api_strict @@ -62,6 +63,9 @@ def set_array_api_strict_flags( versions are: ``{supported_versions}``. The default version number is ``{default_version!r}``. + Note that 2021.12 is supported, but currently gives the same thing as + 2022.12 (except that the fft extension will be disabled). + - `data_dependent_shapes`: Whether data-dependent shapes are enabled in array-api-strict. @@ -118,6 +122,8 @@ def set_array_api_strict_flags( if api_version is not None: if api_version not in supported_versions: raise ValueError(f"Unsupported standard version {api_version!r}") + if api_version == "2021.12": + warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index f9d8ad6..303c930 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -32,8 +32,12 @@ def test_flags(): 'data_dependent_shapes': False, 'enabled_extensions': ('fft',), } - # Make sure setting the version to 2021.12 disables fft - set_array_api_strict_flags(api_version='2021.12') + # Make sure setting the version to 2021.12 disables fft and issues a + # warning. + with pytest.warns(UserWarning) as record: + set_array_api_strict_flags(api_version='2021.12') + assert len(record) == 1 + assert '2021.12' in str(record[0].message) flags = get_array_api_strict_flags() assert flags == { 'api_version': '2021.12', @@ -51,10 +55,11 @@ def test_flags(): enabled_extensions=('linalg', 'fft'))) # Test resetting flags - set_array_api_strict_flags( - api_version='2021.12', - data_dependent_shapes=False, - enabled_extensions=()) + with pytest.warns(UserWarning): + set_array_api_strict_flags( + api_version='2021.12', + data_dependent_shapes=False, + enabled_extensions=()) reset_array_api_strict_flags() flags = get_array_api_strict_flags() assert flags == { From 22352d2f29fc9581959860d6c04026c4d88ba86f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 18 Apr 2024 16:35:13 -0600 Subject: [PATCH 025/252] Update some flags documentation --- array_api_strict/_flags.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 2114620..e0344d8 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -1,14 +1,15 @@ """ -This file defines flags for that allow array-api-strict to be used in -different "modes". These modes include +These functions configure global flags that allow array-api-strict to be +used in different "modes". These modes include - Changing to different supported versions of the standard. - Enabling or disabling different optional behaviors (such as data-dependent shapes). - Enabling or disabling different optional extensions. -Nothing in this file is part of the standard itself. A typical array API +None of these functions are part of the standard itself. A typical array API library will only support one particular configuration of these flags. + """ import functools @@ -112,8 +113,8 @@ def set_array_api_strict_flags( See Also -------- - get_array_api_strict_flags - reset_array_api_strict_flags + get_array_api_strict_flags: Get the current values of flags. + reset_array_api_strict_flags: Reset the flags to their default values. ArrayApiStrictFlags: A context manager to temporarily set the flags. """ @@ -175,8 +176,8 @@ def get_array_api_strict_flags(): See Also -------- - set_array_api_strict_flags - reset_array_api_strict_flags + set_array_api_strict_flags: Set one or more flags to a given value. + reset_array_api_strict_flags: Reset the flags to their default values. ArrayApiStrictFlags: A context manager to temporarily set the flags. """ @@ -207,8 +208,8 @@ def reset_array_api_strict_flags(): See Also -------- - set_array_api_strict_flags - get_array_api_strict_flags + get_array_api_strict_flags: Get the current values of flags. + set_array_api_strict_flags: Set one or more flags to a given value. ArrayApiStrictFlags: A context manager to temporarily set the flags. """ @@ -234,9 +235,9 @@ class ArrayApiStrictFlags: See Also -------- - set_array_api_strict_flags - get_array_api_strict_flags - reset_array_api_strict_flags + get_array_api_strict_flags: Get the current values of flags. + set_array_api_strict_flags: Set one or more flags to a given value. + reset_array_api_strict_flags: Reset the flags to their default values. """ def __init__(self, *, api_version=None, data_dependent_shapes=None, From 654865dd1ca57beddfb931c2cfced2169406cfe0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 18 Apr 2024 16:36:23 -0600 Subject: [PATCH 026/252] Rename ArrayApiStrictFlags to ArrayAPIStrictFlags --- array_api_strict/__init__.py | 4 ++-- array_api_strict/_flags.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index b323a65..3f418d8 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -294,10 +294,10 @@ set_array_api_strict_flags, get_array_api_strict_flags, reset_array_api_strict_flags, - ArrayApiStrictFlags, + ArrayAPIStrictFlags, ) -__all__ += ['set_array_api_strict_flags', 'get_array_api_strict_flags', 'reset_array_api_strict_flags', 'ArrayApiStrictFlags'] +__all__ += ['set_array_api_strict_flags', 'get_array_api_strict_flags', 'reset_array_api_strict_flags', 'ArrayAPIStrictFlags'] from . import _version __version__ = _version.get_versions()['version'] diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index e0344d8..a07393a 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -115,7 +115,7 @@ def set_array_api_strict_flags( get_array_api_strict_flags: Get the current values of flags. reset_array_api_strict_flags: Reset the flags to their default values. - ArrayApiStrictFlags: A context manager to temporarily set the flags. + ArrayAPIStrictFlags: A context manager to temporarily set the flags. """ global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS @@ -178,7 +178,7 @@ def get_array_api_strict_flags(): set_array_api_strict_flags: Set one or more flags to a given value. reset_array_api_strict_flags: Reset the flags to their default values. - ArrayApiStrictFlags: A context manager to temporarily set the flags. + ArrayAPIStrictFlags: A context manager to temporarily set the flags. """ return { @@ -210,7 +210,7 @@ def reset_array_api_strict_flags(): get_array_api_strict_flags: Get the current values of flags. set_array_api_strict_flags: Set one or more flags to a given value. - ArrayApiStrictFlags: A context manager to temporarily set the flags. + ArrayAPIStrictFlags: A context manager to temporarily set the flags. """ global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS @@ -220,7 +220,7 @@ def reset_array_api_strict_flags(): ENABLED_EXTENSIONS = default_extensions -class ArrayApiStrictFlags: +class ArrayAPIStrictFlags: """ A context manager to temporarily set the array-api-strict flags. From 0011d23ade7695d043fda38c08f571a43b928f79 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 18 Apr 2024 17:00:20 -0600 Subject: [PATCH 027/252] Add flags functions to Sphinx documentation --- array_api_strict/_flags.py | 20 ++++++++++--------- docs/api.rst | 39 ++++++++++++++++++++++++++++++++++++++ docs/changelog.md | 2 +- docs/conf.py | 4 ++-- docs/index.md | 2 ++ 5 files changed, 55 insertions(+), 12 deletions(-) create mode 100644 docs/api.rst diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index a07393a..6cc503a 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -91,22 +91,20 @@ def set_array_api_strict_flags( array-api-strict. The default is ``{default_extensions}``. Note that some extensions require a minimum version of the standard. - The default values of the flags can also be changed by setting environment - variables: - - - ``ARRAY_API_STRICT_API_VERSION``: A string representing the version number. - - ``ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES``: "True" or "False". - - ``ARRAY_API_STRICT_ENABLED_EXTENSIONS``: A comma separated list of - extensions to enable. + The flags can also be changed by setting :ref:`environment variables + `. Examples -------- >>> from array_api_strict import set_array_api_strict_flags + >>> # Set the standard version to 2021.12 >>> set_array_api_strict_flags(api_version="2021.12") + >>> # Disable data-dependent shapes >>> set_array_api_strict_flags(data_dependent_shapes=False) + >>> # Enable only the linalg extension (disable the fft extension) >>> set_array_api_strict_flags(enabled_extensions=["linalg"]) @@ -192,13 +190,17 @@ def reset_array_api_strict_flags(): """ Reset the array-api-strict flags to their default values. - This will also reset any flags that were set by environment variables. + This will also reset any flags that were set by :ref:`environment + variables ` back to their default values. .. note:: This function is **not** part of the array API standard. It only exists in array-api-strict. + See :func:`set_array_api_strict_flags` for a list of flags and their + default values. + Examples -------- @@ -229,7 +231,7 @@ class ArrayAPIStrictFlags: This class is **not** part of the array API standard. It only exists in array-api-strict. - See :func:`~.array_api_strict.set_array_api_strict_flags` for a + See :func:`set_array_api_strict_flags` for a description of the available flags. See Also diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 0000000..0827982 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,39 @@ +API Reference +============= + +.. automodule:: array_api_strict + +Array API Strict Flags +---------------------- + +.. automodule:: array_api_strict._flags + +.. currentmodule:: array_api_strict + +.. autofunction:: get_array_api_strict_flags +.. autofunction:: set_array_api_strict_flags +.. autofunction:: reset_array_api_strict_flags +.. autoclass:: ArrayAPIStrictFlags + +.. _environment-variables: + +Environment Variables +~~~~~~~~~~~~~~~~~~~~~ + +Flags can also be set with environment variables. +:func:`set_array_api_strict_flags` will override the values set by environment +variables. Note that the environment variables will only change the defaults +used by array-api-strict initially. They will not change the defaults used by +:func:`reset_array_api_strict_flags`. + +.. envvar:: ARRAY_API_STRICT_API_VERSION + + A string representing the version number. + +.. envvar:: ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES + + "True" or "False" to enable or disable data dependent shapes. + +.. envvar:: ARRAY_API_STRICT_ENABLED_EXTENSIONS + + A comma separated list of extensions to enable. diff --git a/docs/changelog.md b/docs/changelog.md index 8f1c203..04c383d 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -20,7 +20,7 @@ This is the first release of `array_api_strict`. It is extracted from `numpy.array_api`, which was included as an experimental submodule in NumPy versions prior to 2.0. Note that the commit history in this repository is -extracted from the git history of numpy/array_api/ (see the [README](README.md)). +extracted from the git history of numpy/array_api/ (see [](numpy.array_api)). Additionally, the following changes are new to `array_api_strict` from `numpy.array_api` in NumPy 1.26 (the last NumPy feature release to include diff --git a/docs/conf.py b/docs/conf.py index c068b06..e4c66d7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,8 +22,8 @@ extensions = [ 'myst_parser', - # 'sphinx.ext.autodoc', - # 'sphinx.ext.napoleon', + 'sphinx.ext.autodoc', + 'sphinx.ext.napoleon', # 'sphinx.ext.intersphinx', 'sphinx_copybutton', ] diff --git a/docs/index.md b/docs/index.md index 307a9c2..6e84efa 100644 --- a/docs/index.md +++ b/docs/index.md @@ -183,6 +183,7 @@ issue, but this hasn't necessarily been tested thoroughly. API standard. [Support for 2023.12 is planned](https://github.com/data-apis/array-api-strict/issues/25). +(numpy.array_api)= ## Relationship to `numpy.array_api` Previously this implementation was available as `numpy.array_api`, but it was @@ -201,5 +202,6 @@ git_filter_repo.py --path numpy/array_api/ --path-rename numpy/array_api:array_a :titlesonly: :hidden: +api.rst changelog.md ``` From f92b4973251f1a90d56a6be5481cd6a91fd70413 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 18 Apr 2024 17:02:54 -0600 Subject: [PATCH 028/252] Add a note about docs for standard functions --- docs/api.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/api.rst b/docs/api.rst index 0827982..e703a63 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -37,3 +37,12 @@ used by array-api-strict initially. They will not change the defaults used by .. envvar:: ARRAY_API_STRICT_ENABLED_EXTENSIONS A comma separated list of extensions to enable. + +Array API Functions +-------------------- + +All functions and methods in +the array API standard are implemented in array-api-strict. See the `Array API +Standard +`__ for +full documentation for each function. From bb1090b6bd31440adbe2a235fa7574f9ab2551d0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 13:31:46 -0600 Subject: [PATCH 029/252] Don't try to run the publish step unless the commit is a tag This prevents this workflow from failing on every commit in main. --- .github/workflows/publish-package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 66f60fb..bfe98bb 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -69,7 +69,7 @@ jobs: publish: name: Publish Python distribution to (Test)PyPI - if: github.event_name != 'pull_request' && github.repository == 'data-apis/array-api-strict' + if: github.event_name != 'pull_request' && github.repository == 'data-apis/array-api-strict' && github.ref_type == 'tag' needs: build runs-on: ubuntu-latest # Mandatory for publishing with a trusted publisher From ac493c102fb2c65c2e0ab02bdf4f9e3e696e613c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 13:35:31 -0600 Subject: [PATCH 030/252] Increase the array-api-tests max-examples --- .github/workflows/array-api-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index a2ad2fd..bfb7dcf 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -3,7 +3,7 @@ name: Array API Tests on: [push, pull_request] env: - PYTEST_ARGS: "-v -rxXfE --ci --hypothesis-disable-deadline" + PYTEST_ARGS: "-v -rxXfE --ci --hypothesis-disable-deadline --max-examples 200" jobs: array-api-tests: From e61b50d27200e31fd90c766c1e8e9f47a19827e3 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 13:53:44 -0600 Subject: [PATCH 031/252] Add a @requires_api_version decorator --- array_api_strict/_flags.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 6cc503a..9c87229 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -277,6 +277,21 @@ def set_flags_from_environment(): set_flags_from_environment() +# Decorators + +def requires_api_version(version): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if version > API_VERSION: + raise RuntimeError( + f"The function {func.__name__} requires API version {version} or later, " + f"but the current API version for array-api-strict is {API_VERSION}" + ) + return func(*args, **kwargs) + return wrapper + return decorator + def requires_data_dependent_shapes(func): @functools.wraps(func) def wrapper(*args, **kwargs): From f49845aa3e27af808b5f1fd1d0afc724143dc638 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 13:59:58 -0600 Subject: [PATCH 032/252] Don't re-enable disabled extensions when setting the api version --- array_api_strict/_flags.py | 2 +- array_api_strict/tests/test_flags.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 9c87229..cd33290 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -140,7 +140,7 @@ def set_array_api_strict_flags( ) ENABLED_EXTENSIONS = tuple(enabled_extensions) else: - ENABLED_EXTENSIONS = tuple([ext for ext in all_extensions if extension_versions[ext] <= API_VERSION]) + ENABLED_EXTENSIONS = tuple([ext for ext in ENABLED_EXTENSIONS if extension_versions[ext] <= API_VERSION]) # We have to do this separately or it won't get added as the docstring set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 303c930..996a684 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -42,6 +42,16 @@ def test_flags(): assert flags == { 'api_version': '2021.12', 'data_dependent_shapes': False, + 'enabled_extensions': (), + } + reset_array_api_strict_flags() + + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2021.12') + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2021.12', + 'data_dependent_shapes': True, 'enabled_extensions': ('linalg',), } From 71c523198b8f15d9f3920237b191cc9fc7efe710 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 14:01:43 -0600 Subject: [PATCH 033/252] Add support for setting the api version to 2023.12 --- array_api_strict/_flags.py | 6 ++++++ array_api_strict/tests/test_array_object.py | 5 ++++- array_api_strict/tests/test_flags.py | 13 +++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index cd33290..205c325 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -21,6 +21,7 @@ supported_versions = ( "2021.12", "2022.12", + "2023.12", ) API_VERSION = default_version = "2022.12" @@ -67,6 +68,9 @@ def set_array_api_strict_flags( Note that 2021.12 is supported, but currently gives the same thing as 2022.12 (except that the fft extension will be disabled). + 2023.12 support is preliminary. Some features in 2023.12 may still be + missing, and it hasn't been fully tested. + - `data_dependent_shapes`: Whether data-dependent shapes are enabled in array-api-strict. @@ -123,6 +127,8 @@ def set_array_api_strict_flags( raise ValueError(f"Unsupported standard version {api_version!r}") if api_version == "2021.12": warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") + if api_version == "2023.12": + warnings.warn("The 2023.12 version of the array API specification is still preliminary. Some functions are not yet implemented, and it has not been fully tested.") API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index bae0553..9d9dad0 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -410,9 +410,12 @@ def test_array_namespace(): assert a.__array_namespace__(api_version="2022.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2022.12" + assert a.__array_namespace__(api_version="2023.12") is array_api_strict + assert array_api_strict.__array_api_version__ == "2023.12" + with pytest.warns(UserWarning): assert a.__array_namespace__(api_version="2021.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2021.12" pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) - pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2023.12")) + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 996a684..b1ad61f 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -54,6 +54,19 @@ def test_flags(): 'data_dependent_shapes': True, 'enabled_extensions': ('linalg',), } + reset_array_api_strict_flags() + + # 2023.12 should issue a warning + with pytest.warns(UserWarning) as record: + set_array_api_strict_flags(api_version='2023.12') + assert len(record) == 1 + assert '2023.12' in str(record[0].message) + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2023.12', + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } # Test setting flags with invalid values pytest.raises(ValueError, lambda: From c39fdbfcea9c755b72f6387c92c3f8c8f1a7acdd Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 15:48:25 -0600 Subject: [PATCH 034/252] Set the stacklevel in the set_array_api_strict_flags() warnings --- array_api_strict/_flags.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 205c325..476ffb9 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -126,9 +126,9 @@ def set_array_api_strict_flags( if api_version not in supported_versions: raise ValueError(f"Unsupported standard version {api_version!r}") if api_version == "2021.12": - warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") + warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2) if api_version == "2023.12": - warnings.warn("The 2023.12 version of the array API specification is still preliminary. Some functions are not yet implemented, and it has not been fully tested.") + warnings.warn("The 2023.12 version of the array API specification is still preliminary. Some functions are not yet implemented, and it has not been fully tested.", stacklevel=2) API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION From 31c5a89902508087cdaec663631fd51a1f95aeb6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 15:52:32 -0600 Subject: [PATCH 035/252] Add clip() It is only enabled for when the api version is 2023.12. I have only tested that it works manually. There is no test suite support for clip() yet. --- array_api_strict/__init__.py | 2 + array_api_strict/_elementwise_functions.py | 67 ++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 3f418d8..6a9079a 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -134,6 +134,7 @@ bitwise_right_shift, bitwise_xor, ceil, + clip, conj, cos, cosh, @@ -196,6 +197,7 @@ "bitwise_right_shift", "bitwise_xor", "ceil", + "clip", "cos", "cosh", "divide", diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 8b69677..800ee70 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -12,6 +12,11 @@ _result_type, ) from ._array_object import Array +from ._flags import requires_api_version +from ._creation_functions import asarray +from ._utility_functions import any as xp_any + +from typing import Optional, Union import numpy as np @@ -240,6 +245,68 @@ def ceil(x: Array, /) -> Array: return x return Array._new(np.ceil(x._array)) +# WARNING: This function is not yet tested by the array-api-tests test suite. + +# Note: min and max argument names are different and not optional in numpy. +@requires_api_version('2023.12') +def clip( + x: Array, + /, + min: Optional[Union[int, float, Array]] = None, + max: Optional[Union[int, float, Array]] = None, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.clip `. + + See its docstring for more information. + """ + if (x.dtype not in _real_numeric_dtypes + or isinstance(min, Array) and min.dtype not in _real_numeric_dtypes + or isinstance(max, Array) and max.dtype not in _real_numeric_dtypes): + raise TypeError("Only real numeric dtypes are allowed in clip") + if not isinstance(min, (int, float, Array, type(None))): + raise TypeError("min must be an None, int, float, or an array") + if not isinstance(max, (int, float, Array, type(None))): + raise TypeError("max must be an None, int, float, or an array") + + # Mixed dtype kinds is implementation defined + if (x.dtype in _integer_dtypes + and (isinstance(min, float) or + isinstance(min, Array) and min.dtype in _real_floating_dtypes)): + raise TypeError("min must be integral when x is integral") + if (x.dtype in _integer_dtypes + and (isinstance(max, float) or + isinstance(max, Array) and max.dtype in _real_floating_dtypes)): + raise TypeError("max must be integral when x is integral") + if (x.dtype in _real_floating_dtypes + and (isinstance(min, int) or + isinstance(min, Array) and min.dtype in _integer_dtypes)): + raise TypeError("min must be floating-point when x is floating-point") + if (x.dtype in _real_floating_dtypes + and (isinstance(max, int) or + isinstance(max, Array) and max.dtype in _integer_dtypes)): + raise TypeError("max must be floating-point when x is floating-point") + + if min is max is None: + # Note: NumPy disallows min = max = None + return x + + # Normalize to make the below logic simpler + if min is not None: + min = asarray(min)._array + if max is not None: + max = asarray(max)._array + + # min > max is implementation defined + if min is not None and max is not None and np.any(min > max): + raise ValueError("min must be less than or equal to max") + + result = np.clip(x._array, min, max) + # Note: NumPy applies type promotion, but the standard specifies the + # return dtype should be the same as x + if result.dtype != x.dtype._np_dtype: + result = result.astype(x.dtype._np_dtype) + return Array._new(result) def conj(x: Array, /) -> Array: """ From 6a6719ff6115e5bde084682cc5870e012f0a70c5 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 15:58:49 -0600 Subject: [PATCH 036/252] Add a TODO note for clip() --- array_api_strict/_elementwise_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 800ee70..c9272bb 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -305,6 +305,8 @@ def clip( # Note: NumPy applies type promotion, but the standard specifies the # return dtype should be the same as x if result.dtype != x.dtype._np_dtype: + # TODO: I'm not completely sure this always gives the correct thing + # for integer dtypes. See https://github.com/numpy/numpy/issues/24976 result = result.astype(x.dtype._np_dtype) return Array._new(result) From 4e0102b92beb6928bc4ce99cf3af7f39b2188f64 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 16:07:23 -0600 Subject: [PATCH 037/252] Add missing names to __all__ --- array_api_strict/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 3f418d8..f4f2b39 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -83,6 +83,7 @@ "broadcast_to", "can_cast", "finfo", + "isdtype", "iinfo", "result_type", ] @@ -114,6 +115,8 @@ "uint64", "float32", "float64", + "complex64", + "complex128", "bool", ] @@ -196,6 +199,7 @@ "bitwise_right_shift", "bitwise_xor", "ceil", + "conj", "cos", "cosh", "divide", @@ -206,6 +210,7 @@ "floor_divide", "greater", "greater_equal", + "imag", "isfinite", "isinf", "isnan", @@ -225,6 +230,7 @@ "not_equal", "positive", "pow", + "real", "remainder", "round", "sign", From edffda7ebabf126e652f504e4aba424f263d1dff Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 16:14:52 -0600 Subject: [PATCH 038/252] Set up ruff linting --- .github/workflows/ruff.yml | 19 +++++++++++++++++++ ruff.toml | 7 +++++++ 2 files changed, 26 insertions(+) create mode 100644 .github/workflows/ruff.yml create mode 100644 ruff.toml diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 0000000..a9f0fd4 --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,19 @@ +name: CI +on: [push, pull_request] +jobs: + check-ruff: + runs-on: ubuntu-latest + continue-on-error: true + steps: + - uses: actions/checkout@v4 + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff + # Update output format to enable automatic inline annotations. + - name: Run Ruff + run: ruff check --output-format=github . diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..43830f2 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,7 @@ +[lint] +ignore = [ + # Ignore module import not at top of file + "E402", + # Annoying style checks + "E7", +] From ef61f10041d3ceeb72a162adc2407ce0c3f09245 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 16:15:09 -0600 Subject: [PATCH 039/252] Fix some ruff unused import errors --- array_api_strict/_data_type_functions.py | 1 - array_api_strict/_statistical_functions.py | 2 +- array_api_strict/_typing.py | 4 ---- array_api_strict/linalg.py | 1 - array_api_strict/tests/test_array_object.py | 2 +- array_api_strict/tests/test_manipulation_functions.py | 1 - 6 files changed, 2 insertions(+), 9 deletions(-) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 16ff7b4..434eb45 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: from ._typing import Dtype - from collections.abc import Sequence import numpy as np diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 75e29f9..11ef4c6 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -6,7 +6,7 @@ _numeric_dtypes, ) from ._array_object import Array -from ._dtypes import float32, float64, complex64, complex128 +from ._dtypes import float32, complex64 from typing import TYPE_CHECKING, Optional, Tuple, Union diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index bc2f4df..ce25d4c 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -21,10 +21,6 @@ from typing import ( Any, - Literal, - Sequence, - Type, - Union, TypeVar, Protocol, ) diff --git a/array_api_strict/linalg.py b/array_api_strict/linalg.py index e1998fa..c350758 100644 --- a/array_api_strict/linalg.py +++ b/array_api_strict/linalg.py @@ -4,7 +4,6 @@ _floating_dtypes, _numeric_dtypes, float32, - float64, complex64, complex128, ) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index bae0553..e061a94 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from .. import ones, asarray, reshape, result_type, all, equal +from .. import ones, asarray, result_type, all, equal from .._array_object import Array, CPU_DEVICE from .._dtypes import ( _all_dtypes, diff --git a/array_api_strict/tests/test_manipulation_functions.py b/array_api_strict/tests/test_manipulation_functions.py index aec57c3..70b42f3 100644 --- a/array_api_strict/tests/test_manipulation_functions.py +++ b/array_api_strict/tests/test_manipulation_functions.py @@ -1,7 +1,6 @@ from numpy.testing import assert_raises import numpy as np -from .. import all from .._creation_functions import asarray from .._dtypes import float64, int8 from .._manipulation_functions import ( From 7189e2ae230cc0aba2085bed987c2eecd7e156dc Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 16:20:44 -0600 Subject: [PATCH 040/252] Fix or ignore ruff errors --- array_api_strict/_array_object.py | 6 +++--- array_api_strict/_indexing_functions.py | 2 ++ array_api_strict/_manipulation_functions.py | 2 +- array_api_strict/linalg.py | 4 ++-- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 2b9155a..89b3550 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -589,8 +589,8 @@ def __getitem__( key: Union[ int, slice, - ellipsis, - Tuple[Union[int, slice, ellipsis, None], ...], + ellipsis, # noqa: F821 + Tuple[Union[int, slice, ellipsis, None], ...], # noqa: F821 Array, ], /, @@ -780,7 +780,7 @@ def __rshift__(self: Array, other: Union[int, Array], /) -> Array: def __setitem__( self, key: Union[ - int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array + int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array # noqa: F821 ], value: Union[int, float, bool, Array], /, diff --git a/array_api_strict/_indexing_functions.py b/array_api_strict/_indexing_functions.py index baf23f7..7119cb4 100644 --- a/array_api_strict/_indexing_functions.py +++ b/array_api_strict/_indexing_functions.py @@ -3,6 +3,8 @@ from ._array_object import Array from ._dtypes import _integer_dtypes +from typing import Optional + import numpy as np def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index 25a2754..c7abf32 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -57,7 +57,7 @@ def reshape(x: Array, /, shape: Tuple[int, ...], *, - copy: Optional[Bool] = None) -> Array: + copy: Optional[bool] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.reshape `. diff --git a/array_api_strict/linalg.py b/array_api_strict/linalg.py index c350758..1f548f0 100644 --- a/array_api_strict/linalg.py +++ b/array_api_strict/linalg.py @@ -159,7 +159,7 @@ def inv(x: Array, /) -> Array: # -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point # literals. @requires_extension('linalg') -def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array: +def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array: # noqa: F821 """ Array API compatible wrapper for :py:func:`np.linalg.norm `. @@ -251,7 +251,7 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: return Array._new(np.linalg.pinv(x._array, rcond=rtol)) @requires_extension('linalg') -def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: +def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: # noqa: F821 """ Array API compatible wrapper for :py:func:`np.linalg.qr `. From edd87f0d3510eb9a2383d625c7a6f88e49429638 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 16:25:58 -0600 Subject: [PATCH 041/252] Move all typing imports under "if TYPE_CHECKING" --- array_api_strict/_array_object.py | 5 +++-- array_api_strict/_data_type_functions.py | 3 ++- array_api_strict/_indexing_functions.py | 5 ++++- array_api_strict/_manipulation_functions.py | 5 ++++- array_api_strict/_searching_functions.py | 4 +++- array_api_strict/_statistical_functions.py | 3 ++- array_api_strict/_utility_functions.py | 4 +++- array_api_strict/fft.py | 3 ++- 8 files changed, 23 insertions(+), 9 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 89b3550..f77c38d 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -33,11 +33,12 @@ ) from ._flags import get_array_api_strict_flags, set_array_api_strict_flags -from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex +from typing import TYPE_CHECKING, SupportsIndex import types if TYPE_CHECKING: - from ._typing import Any, PyCapsule, Device, Dtype + from typing import Optional, Tuple, Union, Any + from ._typing import PyCapsule, Device, Dtype import numpy.typing as npt import numpy as np diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 434eb45..41f70c5 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -15,9 +15,10 @@ ) from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Tuple, Union +from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import List, Tuple, Union from ._typing import Dtype import numpy as np diff --git a/array_api_strict/_indexing_functions.py b/array_api_strict/_indexing_functions.py index 7119cb4..316a3a7 100644 --- a/array_api_strict/_indexing_functions.py +++ b/array_api_strict/_indexing_functions.py @@ -3,7 +3,10 @@ from ._array_object import Array from ._dtypes import _integer_dtypes -from typing import Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Optional import numpy as np diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index c7abf32..af9a3dd 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -3,7 +3,10 @@ from ._array_object import Array from ._data_type_functions import result_type -from typing import List, Optional, Tuple, Union +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import List, Optional, Tuple, Union import numpy as np diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 9781531..1ef2556 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -4,7 +4,9 @@ from ._dtypes import _result_type, _real_numeric_dtypes from ._flags import requires_data_dependent_shapes -from typing import Optional, Tuple +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional, Tuple import numpy as np diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 11ef4c6..cbe9d0d 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -8,9 +8,10 @@ from ._array_object import Array from ._dtypes import float32, complex64 -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Optional, Tuple, Union from ._typing import Dtype import numpy as np diff --git a/array_api_strict/_utility_functions.py b/array_api_strict/_utility_functions.py index 5ecb4bd..c91fa58 100644 --- a/array_api_strict/_utility_functions.py +++ b/array_api_strict/_utility_functions.py @@ -2,7 +2,9 @@ from ._array_object import Array -from typing import Optional, Tuple, Union +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional, Tuple, Union import numpy as np diff --git a/array_api_strict/fft.py b/array_api_strict/fft.py index 7f427e5..32b9551 100644 --- a/array_api_strict/fft.py +++ b/array_api_strict/fft.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union, Optional, Literal +from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Union, Optional, Literal from ._typing import Device from collections.abc import Sequence From 77e6177eecf4b330101965bf4f93ae18da01041e Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 16:30:13 -0600 Subject: [PATCH 042/252] Remove unused import --- array_api_strict/_elementwise_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index c9272bb..ea52d96 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -14,7 +14,6 @@ from ._array_object import Array from ._flags import requires_api_version from ._creation_functions import asarray -from ._utility_functions import any as xp_any from typing import Optional, Union From 04c24d72be76305a19fb1de4153961b3848330aa Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 16:35:29 -0600 Subject: [PATCH 043/252] Add copysign copysign is not tested yet by the test suite, but the standard does not appear to deviate from NumPy (except in the restriction to floating-point dtypes). --- array_api_strict/__init__.py | 2 ++ array_api_strict/_elementwise_functions.py | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index e2212f1..9d9aca6 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -139,6 +139,7 @@ ceil, clip, conj, + copysign, cos, cosh, divide, @@ -202,6 +203,7 @@ "ceil", "clip", "conj", + "copysign", "cos", "cosh", "divide", diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index ea52d96..994bcb2 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -319,6 +319,16 @@ def conj(x: Array, /) -> Array: raise TypeError("Only complex floating-point dtypes are allowed in conj") return Array._new(np.conj(x)) +@requires_api_version('2023.12') +def copysign(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.copysign `. + + See its docstring for more information. + """ + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in copysign") + return Array._new(np.copysign(x1._array, x2._array)) def cos(x: Array, /) -> Array: """ From c4587a4654db065ef8dba7e9115f43cacc871774 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 19 Apr 2024 23:38:11 -0600 Subject: [PATCH 044/252] Implement cumulative_sum (still needs to be tested) --- array_api_strict/__init__.py | 4 ++-- array_api_strict/_statistical_functions.py | 25 ++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 9d9aca6..c1fba30 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -290,9 +290,9 @@ __all__ += ["argsort", "sort"] -from ._statistical_functions import max, mean, min, prod, std, sum, var +from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var -__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"] +__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"] from ._utility_functions import all, any diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index cbe9d0d..c65f50e 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -7,6 +7,9 @@ ) from ._array_object import Array from ._dtypes import float32, complex64 +from ._flags import requires_api_version +from ._creation_functions import zeros +from ._manipulation_functions import concat from typing import TYPE_CHECKING @@ -16,6 +19,28 @@ import numpy as np +@requires_api_version('2023.12') +def cumulative_sum( + x: Array, + /, + *, + axis: Optional[int] = None, + dtype: Optional[Dtype] = None, + include_initial: bool = False, +) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in cumulative_sum") + if dtype is None: + dtype = x.dtype + + if axis is None: + if x.ndim > 1: + raise ValueError("axis must be specified in cumulative_sum for more than one dimension") + axis = 0 + # np.cumsum does not support include_initial + if include_initial: + x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dtype), x], axis=axis) + return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype._np_dtype)) def max( x: Array, From 16b38d305a160e3d4825b65aea9b3ad2c5b54b48 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 22 Apr 2024 14:53:58 -0600 Subject: [PATCH 045/252] Add a comment about cumulative_sum and 0-D inputs --- array_api_strict/_statistical_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index c65f50e..b35d26f 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -33,6 +33,7 @@ def cumulative_sum( if dtype is None: dtype = x.dtype + # TODO: The standard is not clear about what should happen when x.ndim == 0. if axis is None: if x.ndim > 1: raise ValueError("axis must be specified in cumulative_sum for more than one dimension") From b689d43b9906b5e9091c23f628a0a7bfee7654d1 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 22 Apr 2024 15:18:46 -0600 Subject: [PATCH 046/252] Add hypot() This is untested, but the NumPy hypot() should match the standard. --- array_api_strict/__init__.py | 2 ++ array_api_strict/_elementwise_functions.py | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index c1fba30..b9df986 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -150,6 +150,7 @@ floor_divide, greater, greater_equal, + hypot, imag, isfinite, isinf, @@ -214,6 +215,7 @@ "floor_divide", "greater", "greater_equal", + "hypot", "imag", "isfinite", "isinf", diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 994bcb2..f144c69 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -455,6 +455,18 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater_equal(x1._array, x2._array)) +def hypot(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.hypot `. + + See its docstring for more information. + """ + if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in hypot") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.hypot(x1._array, x2._array)) def imag(x: Array, /) -> Array: """ From 9ee08c7252e5e924755da043d9a4a3bd2c80b9bc Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 22 Apr 2024 15:26:46 -0600 Subject: [PATCH 047/252] Update elementwise tests for new elementwise functions Also add a meta-test to ensure the elementwise tests stay up-to-date. --- .../tests/test_elementwise_functions.py | 140 ++++++++++-------- 1 file changed, 76 insertions(+), 64 deletions(-) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 1228d0a..abb02f8 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,4 +1,4 @@ -from inspect import getfullargspec +from inspect import getfullargspec, getmodule from numpy.testing import assert_raises @@ -10,79 +10,88 @@ _floating_dtypes, _integer_dtypes, ) - +from .._flags import set_array_api_strict_flags def nargs(func): return len(getfullargspec(func).args) +elementwise_function_input_types = { + "abs": "numeric", + "acos": "floating-point", + "acosh": "floating-point", + "add": "numeric", + "asin": "floating-point", + "asinh": "floating-point", + "atan": "floating-point", + "atan2": "real floating-point", + "atanh": "floating-point", + "bitwise_and": "integer or boolean", + "bitwise_invert": "integer or boolean", + "bitwise_left_shift": "integer", + "bitwise_or": "integer or boolean", + "bitwise_right_shift": "integer", + "bitwise_xor": "integer or boolean", + "ceil": "real numeric", + "clip": "real numeric", + "conj": "complex floating-point", + "copysign": "real floating-point", + "cos": "floating-point", + "cosh": "floating-point", + "divide": "floating-point", + "equal": "all", + "exp": "floating-point", + "expm1": "floating-point", + "floor": "real numeric", + "floor_divide": "real numeric", + "greater": "real numeric", + "greater_equal": "real numeric", + "hypot": "real floating-point", + "imag": "complex floating-point", + "isfinite": "numeric", + "isinf": "numeric", + "isnan": "numeric", + "less": "real numeric", + "less_equal": "real numeric", + "log": "floating-point", + "logaddexp": "real floating-point", + "log10": "floating-point", + "log1p": "floating-point", + "log2": "floating-point", + "logical_and": "boolean", + "logical_not": "boolean", + "logical_or": "boolean", + "logical_xor": "boolean", + "multiply": "numeric", + "negative": "numeric", + "not_equal": "all", + "positive": "numeric", + "pow": "numeric", + "real": "complex floating-point", + "remainder": "real numeric", + "round": "numeric", + "sign": "numeric", + "sin": "floating-point", + "sinh": "floating-point", + "sqrt": "floating-point", + "square": "numeric", + "subtract": "numeric", + "tan": "floating-point", + "tanh": "floating-point", + "trunc": "real numeric", +} + +def test_missing_functions(): + # Ensure the above dictionary is complete. + import array_api_strict._elementwise_functions as mod + mod_funcs = [n for n in dir(mod) if getmodule(getattr(mod, n)) is mod] + assert set(mod_funcs) == set(elementwise_function_input_types) + def test_function_types(): # Test that every function accepts only the required input types. We only # test the negative cases here (error). The positive cases are tested in # the array API test suite. - elementwise_function_input_types = { - "abs": "numeric", - "acos": "floating-point", - "acosh": "floating-point", - "add": "numeric", - "asin": "floating-point", - "asinh": "floating-point", - "atan": "floating-point", - "atan2": "real floating-point", - "atanh": "floating-point", - "bitwise_and": "integer or boolean", - "bitwise_invert": "integer or boolean", - "bitwise_left_shift": "integer", - "bitwise_or": "integer or boolean", - "bitwise_right_shift": "integer", - "bitwise_xor": "integer or boolean", - "ceil": "real numeric", - "conj": "complex floating-point", - "cos": "floating-point", - "cosh": "floating-point", - "divide": "floating-point", - "equal": "all", - "exp": "floating-point", - "expm1": "floating-point", - "floor": "real numeric", - "floor_divide": "real numeric", - "greater": "real numeric", - "greater_equal": "real numeric", - "imag": "complex floating-point", - "isfinite": "numeric", - "isinf": "numeric", - "isnan": "numeric", - "less": "real numeric", - "less_equal": "real numeric", - "log": "floating-point", - "logaddexp": "real floating-point", - "log10": "floating-point", - "log1p": "floating-point", - "log2": "floating-point", - "logical_and": "boolean", - "logical_not": "boolean", - "logical_or": "boolean", - "logical_xor": "boolean", - "multiply": "numeric", - "negative": "numeric", - "not_equal": "all", - "positive": "numeric", - "pow": "numeric", - "real": "complex floating-point", - "remainder": "real numeric", - "round": "numeric", - "sign": "numeric", - "sin": "floating-point", - "sinh": "floating-point", - "sqrt": "floating-point", - "square": "numeric", - "subtract": "numeric", - "tan": "floating-point", - "tanh": "floating-point", - "trunc": "real numeric", - } - def _array_vals(): for d in _integer_dtypes: yield asarray(1, dtype=d) @@ -91,6 +100,9 @@ def _array_vals(): for d in _floating_dtypes: yield asarray(1.0, dtype=d) + # Use the latest version of the standard so all functions are included + set_array_api_strict_flags(api_version="2023.12") + for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): dtypes = _dtype_categories[types] From e24f55ea9a6cef94e4c025429f29f098db353b40 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 22 Apr 2024 15:32:52 -0600 Subject: [PATCH 048/252] Clear trailing whitespace --- array_api_strict/tests/test_manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/tests/test_manipulation_functions.py b/array_api_strict/tests/test_manipulation_functions.py index 70b42f3..9969651 100644 --- a/array_api_strict/tests/test_manipulation_functions.py +++ b/array_api_strict/tests/test_manipulation_functions.py @@ -25,7 +25,7 @@ def test_reshape_copy(): a = asarray(np.ones((2, 3))) b = reshape(a, (3, 2), copy=True) assert not np.shares_memory(a._array, b._array) - + a = asarray(np.ones((2, 3))) b = reshape(a, (3, 2), copy=False) assert np.shares_memory(a._array, b._array) From f5fbf78055d0c9d677a169ef2d5a82dc214b0430 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 22 Apr 2024 15:39:13 -0600 Subject: [PATCH 049/252] Silence warnings output in the tests --- array_api_strict/tests/test_array_object.py | 3 ++- array_api_strict/tests/test_elementwise_functions.py | 5 ++++- array_api_strict/tests/test_flags.py | 10 ++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 24fcf57..a66637f 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -410,7 +410,8 @@ def test_array_namespace(): assert a.__array_namespace__(api_version="2022.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2022.12" - assert a.__array_namespace__(api_version="2023.12") is array_api_strict + with pytest.warns(UserWarning): + assert a.__array_namespace__(api_version="2023.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2023.12" with pytest.warns(UserWarning): diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index abb02f8..3bfcbae 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -12,6 +12,8 @@ ) from .._flags import set_array_api_strict_flags +import pytest + def nargs(func): return len(getfullargspec(func).args) @@ -101,7 +103,8 @@ def _array_vals(): yield asarray(1.0, dtype=d) # Use the latest version of the standard so all functions are included - set_array_api_strict_flags(api_version="2023.12") + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version="2023.12") for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index b1ad61f..f6fbc0d 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -73,9 +73,10 @@ def test_flags(): set_array_api_strict_flags(api_version='2020.12')) pytest.raises(ValueError, lambda: set_array_api_strict_flags( enabled_extensions=('linalg', 'fft', 'invalid'))) - pytest.raises(ValueError, lambda: set_array_api_strict_flags( - api_version='2021.12', - enabled_extensions=('linalg', 'fft'))) + with pytest.warns(UserWarning): + pytest.raises(ValueError, lambda: set_array_api_strict_flags( + api_version='2021.12', + enabled_extensions=('linalg', 'fft'))) # Test resetting flags with pytest.warns(UserWarning): @@ -96,7 +97,8 @@ def test_api_version(): assert xp.__array_api_version__ == '2022.12' # Test setting the version - set_array_api_strict_flags(api_version='2021.12') + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2021.12') assert xp.__array_api_version__ == '2021.12' def test_data_dependent_shapes(): From 3e2d46de96b6c87f5a2552a45e6aed5f8ab7a882 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 23 Apr 2024 16:27:32 -0600 Subject: [PATCH 050/252] Add missing requires_api_version decorator to hypot() --- array_api_strict/_elementwise_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index f144c69..d1e589b 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -455,6 +455,7 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater_equal(x1._array, x2._array)) +@requires_api_version('2023.12') def hypot(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.hypot `. From 250ba869fc786e61c5c5ad3ef3bf66dc1cd24b71 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 23 Apr 2024 16:27:50 -0600 Subject: [PATCH 051/252] Add maximum and minimum --- array_api_strict/__init__.py | 4 +++ array_api_strict/_elementwise_functions.py | 29 +++++++++++++++++++ .../tests/test_elementwise_functions.py | 2 ++ 3 files changed, 35 insertions(+) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index b9df986..3c0e147 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -166,6 +166,8 @@ logical_not, logical_or, logical_xor, + maximum, + minimum, multiply, negative, not_equal, @@ -231,6 +233,8 @@ "logical_not", "logical_or", "logical_xor", + "maximum", + "minimum", "multiply", "negative", "not_equal", diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index d1e589b..a82818b 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -651,6 +651,35 @@ def logical_xor(x1: Array, x2: Array, /) -> Array: x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.logical_xor(x1._array, x2._array)) +@requires_api_version('2023.12') +def maximum(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.maximum `. + + See its docstring for more information. + """ + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in maximum") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + # TODO: maximum(-0., 0.) is unspecified. Should we issue a warning/error + # in that case? + return Array._new(np.maximum(x1._array, x2._array)) + +@requires_api_version('2023.12') +def minimum(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.minimum `. + + See its docstring for more information. + """ + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in minimum") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.minimum(x1._array, x2._array)) def multiply(x1: Array, x2: Array, /) -> Array: """ diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 3bfcbae..6b4a5ec 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -64,6 +64,8 @@ def nargs(func): "logical_not": "boolean", "logical_or": "boolean", "logical_xor": "boolean", + "maximum": "real numeric", + "minimum": "real numeric", "multiply": "numeric", "negative": "numeric", "not_equal": "all", From eb063e2580a7abe65b2cdd92738663f2b4fb84ae Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 23 Apr 2024 16:35:16 -0600 Subject: [PATCH 052/252] Add moveaxis --- array_api_strict/__init__.py | 3 ++- array_api_strict/_manipulation_functions.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 3c0e147..5110c72 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -275,6 +275,7 @@ concat, expand_dims, flip, + moveaxis, permute_dims, reshape, roll, @@ -282,7 +283,7 @@ stack, ) -__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"] +__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "reshape", "roll", "squeeze", "stack"] from ._searching_functions import argmax, argmin, nonzero, where diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index af9a3dd..c22ea1b 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -2,6 +2,7 @@ from ._array_object import Array from ._data_type_functions import result_type +from ._flags import requires_api_version from typing import TYPE_CHECKING @@ -43,6 +44,20 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> """ return Array._new(np.flip(x._array, axis=axis)) +@requires_api_version('2023.12') +def moveaxis( + x: Array, + source: Union[int, Tuple[int, ...]], + destination: Union[int, Tuple[int, ...]], + /, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.moveaxis `. + + See its docstring for more information. + """ + return Array._new(np.moveaxis(x._array, source, destination)) + # Note: The function name is different here (see also matrix_transpose). # Unlike transpose(), the axes argument is required. From 993805957ba9849d66bdb79e7a18870bd31d96ee Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 23 Apr 2024 16:49:31 -0600 Subject: [PATCH 053/252] Add repeat() --- array_api_strict/__init__.py | 3 ++- array_api_strict/_flags.py | 6 ++--- array_api_strict/_manipulation_functions.py | 27 +++++++++++++++++++-- array_api_strict/tests/test_flags.py | 11 +++++++-- 4 files changed, 39 insertions(+), 8 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 5110c72..39eafc0 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -277,13 +277,14 @@ flip, moveaxis, permute_dims, + repeat, reshape, roll, squeeze, stack, ) -__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "reshape", "roll", "squeeze", "stack"] +__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack"] from ._searching_functions import argmax, argmin, nonzero, where diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 476ffb9..fd36139 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -81,10 +81,10 @@ def set_array_api_strict_flags( The functions that make use of data-dependent shapes, and are therefore disabled by setting this flag to False are - - `unique_all`, `unique_counts`, `unique_inverse`, and `unique_values`. - - `nonzero` + - `unique_all()`, `unique_counts()`, `unique_inverse()`, and `unique_values()`. + - `nonzero()` - Boolean array indexing - - `repeat` when the `repeats` argument is an array (requires 2023.12 + - `repeat()` when the `repeats` argument is an array (requires 2023.12 version of the standard) See diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index c22ea1b..1f9a50f 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -1,8 +1,9 @@ from __future__ import annotations from ._array_object import Array +from ._creation_functions import asarray from ._data_type_functions import result_type -from ._flags import requires_api_version +from ._flags import requires_api_version, get_array_api_strict_flags from typing import TYPE_CHECKING @@ -58,7 +59,6 @@ def moveaxis( """ return Array._new(np.moveaxis(x._array, source, destination)) - # Note: The function name is different here (see also matrix_transpose). # Unlike transpose(), the axes argument is required. def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: @@ -69,6 +69,29 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: """ return Array._new(np.transpose(x._array, axes)) +@requires_api_version('2023.12') +def repeat( + x: Array, + repeats: Union[int, Array], + /, + *, + axis: Optional[int] = None, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.repeat `. + + See its docstring for more information. + """ + if isinstance(repeats, Array): + data_dependent_shapes = get_array_api_strict_flags()['data_dependent_shapes'] + if not data_dependent_shapes: + raise RuntimeError("repeat() with repeats as an array requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") + elif isinstance(repeats, int): + repeats = asarray(repeats) + else: + raise TypeError("repeats must be an int or array") + + return Array._new(np.repeat(x._array, repeats, axis=axis)) # Note: the optional argument is called 'shape', not 'newshape' def reshape(x: Array, diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index f6fbc0d..2eba40c 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -2,7 +2,7 @@ reset_array_api_strict_flags) from .. import (asarray, unique_all, unique_counts, unique_inverse, - unique_values, nonzero) + unique_values, nonzero, repeat) import array_api_strict as xp @@ -102,8 +102,12 @@ def test_api_version(): assert xp.__array_api_version__ == '2021.12' def test_data_dependent_shapes(): + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2023.12') # to enable repeat() + a = asarray([0, 0, 1, 2, 2]) mask = asarray([True, False, True, False, True]) + repeats = asarray([1, 1, 2, 2, 2]) # Should not error unique_all(a) @@ -112,7 +116,8 @@ def test_data_dependent_shapes(): unique_values(a) nonzero(a) a[mask] - # TODO: add repeat when it is implemented + repeat(a, repeats) + repeat(a, 2) set_array_api_strict_flags(data_dependent_shapes=False) @@ -122,6 +127,8 @@ def test_data_dependent_shapes(): pytest.raises(RuntimeError, lambda: unique_values(a)) pytest.raises(RuntimeError, lambda: nonzero(a)) pytest.raises(RuntimeError, lambda: a[mask]) + pytest.raises(RuntimeError, lambda: repeat(a, repeats)) + repeat(a, 2) # Should never error linalg_examples = { 'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)), From 095be2f32a8373a95ae094d249fdaa0f765cb30c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 23 Apr 2024 16:53:58 -0600 Subject: [PATCH 054/252] Require the repeats array to have an integer dtype NumPy allows it to be bool (casting it to int). --- array_api_strict/_manipulation_functions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index 1f9a50f..3380b4e 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -3,6 +3,7 @@ from ._array_object import Array from ._creation_functions import asarray from ._data_type_functions import result_type +from ._dtypes import _integer_dtypes from ._flags import requires_api_version, get_array_api_strict_flags from typing import TYPE_CHECKING @@ -86,6 +87,8 @@ def repeat( data_dependent_shapes = get_array_api_strict_flags()['data_dependent_shapes'] if not data_dependent_shapes: raise RuntimeError("repeat() with repeats as an array requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") + if repeats.dtype not in _integer_dtypes: + raise TypeError("The repeats array must have an integer dtype") elif isinstance(repeats, int): repeats = asarray(repeats) else: From 1c4460d8abeeaf8235ef1fd081ff1a56667f844e Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 24 Apr 2024 16:13:53 -0600 Subject: [PATCH 055/252] Add searchsorted As far as I can tell, except for the dtype restriction, the standard is the same as NumPy. --- array_api_strict/__init__.py | 4 ++-- array_api_strict/_searching_functions.py | 24 ++++++++++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 39eafc0..d9a4aab 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -286,9 +286,9 @@ __all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack"] -from ._searching_functions import argmax, argmin, nonzero, where +from ._searching_functions import argmax, argmin, nonzero, searchsorted, where -__all__ += ["argmax", "argmin", "nonzero", "where"] +__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"] from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 1ef2556..89e50f3 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -2,11 +2,11 @@ from ._array_object import Array from ._dtypes import _result_type, _real_numeric_dtypes -from ._flags import requires_data_dependent_shapes +from ._flags import requires_data_dependent_shapes, requires_api_version from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Tuple + from typing import Literal, Optional, Tuple import numpy as np @@ -45,6 +45,26 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]: raise ValueError("nonzero is not allowed on 0-dimensional arrays") return tuple(Array._new(i) for i in np.nonzero(x._array)) +@requires_api_version('2023.12') +def searchsorted( + x1: Array, + x2: Array, + /, + *, + side: Literal["left", "right"] = "left", + sorter: Optional[Array] = None, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.searchsorted `. + + See its docstring for more information. + """ + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in searchsorted") + sorter = sorter._array if sorter is not None else None + # TODO: The sort order of nans and signed zeros is implementation + # dependent. Should we error/warn if they are present? + return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter)) def where(condition: Array, x1: Array, x2: Array, /) -> Array: """ From 730e71616a22e05d3ce88654233b6ebefa794774 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 24 Apr 2024 16:18:54 -0600 Subject: [PATCH 056/252] Add comment about x1 being 1-D in searchsorted --- array_api_strict/_searching_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 89e50f3..7314895 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -64,6 +64,8 @@ def searchsorted( sorter = sorter._array if sorter is not None else None # TODO: The sort order of nans and signed zeros is implementation # dependent. Should we error/warn if they are present? + + # x1 must be 1-D, but NumPy already requires this. return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter)) def where(condition: Array, x1: Array, x2: Array, /) -> Array: From f26bd499f4ec2e68a6555258c8868bbd02e71586 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 24 Apr 2024 16:19:57 -0600 Subject: [PATCH 057/252] Add signbit --- array_api_strict/__init__.py | 2 ++ array_api_strict/_elementwise_functions.py | 12 ++++++++++++ array_api_strict/tests/test_elementwise_functions.py | 1 + 3 files changed, 15 insertions(+) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index d9a4aab..7c9bbef 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -177,6 +177,7 @@ remainder, round, sign, + signbit, sin, sinh, square, @@ -244,6 +245,7 @@ "remainder", "round", "sign", + "signbit", "sin", "sinh", "square", diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index a82818b..9ef71bd 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -791,6 +791,18 @@ def sign(x: Array, /) -> Array: return Array._new(np.sign(x._array)) +@requires_api_version('2023.12') +def signbit(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.signbit `. + + See its docstring for more information. + """ + if x.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in signbit") + return Array._new(np.signbit(x._array)) + + def sin(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.sin `. diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 6b4a5ec..90994f3 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -75,6 +75,7 @@ def nargs(func): "remainder": "real numeric", "round": "numeric", "sign": "numeric", + "signbit": "real floating-point", "sin": "floating-point", "sinh": "floating-point", "sqrt": "floating-point", From dc1baad8ef7bfe07d1563647952930653e24c418 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 24 Apr 2024 16:26:18 -0600 Subject: [PATCH 058/252] Add tile() --- array_api_strict/__init__.py | 3 ++- array_api_strict/_manipulation_functions.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 7c9bbef..5cd6e53 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -284,9 +284,10 @@ roll, squeeze, stack, + tile, ) -__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack"] +__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile"] from ._searching_functions import argmax, argmin, nonzero, searchsorted, where diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index 3380b4e..ee6066f 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -154,3 +154,16 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> result_type(*arrays) arrays = tuple(a._array for a in arrays) return Array._new(np.stack(arrays, axis=axis)) + + +@requires_api_version('2023.12') +def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.tile `. + + See its docstring for more information. + """ + # Note: NumPy allows repetitions to be an int or array + if not isinstance(repetitions, tuple): + raise TypeError("repetitions must be a tuple") + return Array._new(np.tile(x._array, repetitions)) From a30536b8a32f6e15031d629dbc4619595aaf2b6d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 24 Apr 2024 16:34:38 -0600 Subject: [PATCH 059/252] Add unstack() --- array_api_strict/__init__.py | 3 ++- array_api_strict/_manipulation_functions.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 5cd6e53..17cb2c3 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -285,9 +285,10 @@ squeeze, stack, tile, + unstack, ) -__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile"] +__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"] from ._searching_functions import argmax, argmin, nonzero, searchsorted, where diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index ee6066f..7652028 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -167,3 +167,15 @@ def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array: if not isinstance(repetitions, tuple): raise TypeError("repetitions must be a tuple") return Array._new(np.tile(x._array, repetitions)) + +# Note: this function is new +@requires_api_version('2023.12') +def unstack(x: Array, /, *, axis: int = 0) -> Tuple[Array, ...]: + if not (-x.ndim <= axis < x.ndim): + raise ValueError("axis out of range") + + if axis < 0: + axis += x.ndim + + slices = (slice(None),) * axis + return tuple(x[slices + (i, ...)] for i in range(x.shape[axis])) From 943756f4ce03dac99fcda412e79be05221063913 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 25 Apr 2024 14:29:47 -0600 Subject: [PATCH 060/252] Make boolean_indexing a separate flag from data_dependent_shapes It is separate in the inspection API, so we try to match that. --- array_api_strict/_array_object.py | 4 +-- array_api_strict/_flags.py | 42 ++++++++++++++++++++++------ array_api_strict/tests/test_flags.py | 17 +++++++++++ docs/api.rst | 4 +++ 4 files changed, 57 insertions(+), 10 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index f77c38d..8849ce3 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -436,8 +436,8 @@ def _validate_index(self, key): f"{len(key)=}, but masking is only specified in the " "Array API when the array is the sole index." ) - if not get_array_api_strict_flags()['data_dependent_shapes']: - raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") + if not get_array_api_strict_flags()['boolean_indexing']: + raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the boolean_indexing flag has been disabled for array-api-strict") elif i.dtype in _integer_dtypes and i.ndim != 0: raise IndexError( diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 6cc503a..3bf5664 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -25,6 +25,8 @@ API_VERSION = default_version = "2022.12" +BOOLEAN_INDEXING = True + DATA_DEPENDENT_SHAPES = True all_extensions = ( @@ -46,6 +48,7 @@ def set_array_api_strict_flags( *, api_version=None, + boolean_indexing=None, data_dependent_shapes=None, enabled_extensions=None, ): @@ -67,6 +70,12 @@ def set_array_api_strict_flags( Note that 2021.12 is supported, but currently gives the same thing as 2022.12 (except that the fft extension will be disabled). + + - `boolean_indexing`: Whether indexing by a boolean array is supported. + Note that although boolean array indexing does result in data-dependent + shapes, this flag is independent of the `data_dependent_shapes` flag + (see below). + - `data_dependent_shapes`: Whether data-dependent shapes are enabled in array-api-strict. @@ -79,10 +88,12 @@ def set_array_api_strict_flags( - `unique_all`, `unique_counts`, `unique_inverse`, and `unique_values`. - `nonzero` - - Boolean array indexing - `repeat` when the `repeats` argument is an array (requires 2023.12 version of the standard) + Note that while boolean indexing is also data-dependent, it is + controlled by a separate `boolean_indexing` flag (see above). + See https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html for more details. @@ -102,8 +113,8 @@ def set_array_api_strict_flags( >>> # Set the standard version to 2021.12 >>> set_array_api_strict_flags(api_version="2021.12") - >>> # Disable data-dependent shapes - >>> set_array_api_strict_flags(data_dependent_shapes=False) + >>> # Disable data-dependent shapes and boolean indexing + >>> set_array_api_strict_flags(data_dependent_shapes=False, boolean_indexing=False) >>> # Enable only the linalg extension (disable the fft extension) >>> set_array_api_strict_flags(enabled_extensions=["linalg"]) @@ -116,7 +127,7 @@ def set_array_api_strict_flags( ArrayAPIStrictFlags: A context manager to temporarily set the flags. """ - global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS + global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS if api_version is not None: if api_version not in supported_versions: @@ -126,6 +137,9 @@ def set_array_api_strict_flags( API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION + if boolean_indexing is not None: + BOOLEAN_INDEXING = boolean_indexing + if data_dependent_shapes is not None: DATA_DEPENDENT_SHAPES = data_dependent_shapes @@ -169,7 +183,11 @@ def get_array_api_strict_flags(): >>> from array_api_strict import get_array_api_strict_flags >>> flags = get_array_api_strict_flags() >>> flags - {'api_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')} + {'api_version': '2022.12', + 'boolean_indexing': True, + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft') + } See Also -------- @@ -181,6 +199,7 @@ def get_array_api_strict_flags(): """ return { "api_version": API_VERSION, + "boolean_indexing": BOOLEAN_INDEXING, "data_dependent_shapes": DATA_DEPENDENT_SHAPES, "enabled_extensions": ENABLED_EXTENSIONS, } @@ -215,9 +234,10 @@ def reset_array_api_strict_flags(): ArrayAPIStrictFlags: A context manager to temporarily set the flags. """ - global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS + global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS API_VERSION = default_version array_api_strict.__array_api_version__ = API_VERSION + BOOLEAN_INDEXING = True DATA_DEPENDENT_SHAPES = True ENABLED_EXTENSIONS = default_extensions @@ -242,10 +262,11 @@ class ArrayAPIStrictFlags: reset_array_api_strict_flags: Reset the flags to their default values. """ - def __init__(self, *, api_version=None, data_dependent_shapes=None, - enabled_extensions=None): + def __init__(self, *, api_version=None, boolean_indexing=None, + data_dependent_shapes=None, enabled_extensions=None): self.kwargs = { "api_version": api_version, + "boolean_indexing": boolean_indexing, "data_dependent_shapes": data_dependent_shapes, "enabled_extensions": enabled_extensions, } @@ -265,6 +286,11 @@ def set_flags_from_environment(): api_version=os.environ["ARRAY_API_STRICT_API_VERSION"] ) + if "ARRAY_API_STRICT_BOOLEAN_INDEXING" in os.environ: + set_array_api_strict_flags( + boolean_indexing=os.environ["ARRAY_API_STRICT_BOOLEAN_INDEXING"].lower() == "true" + ) + if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ: set_array_api_strict_flags( data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true" diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 303c930..dcf4522 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -13,6 +13,7 @@ def test_flags(): flags = get_array_api_strict_flags() assert flags == { 'api_version': '2022.12', + 'boolean_indexing': True, 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft'), } @@ -22,6 +23,7 @@ def test_flags(): flags = get_array_api_strict_flags() assert flags == { 'api_version': '2022.12', + 'boolean_indexing': True, 'data_dependent_shapes': False, 'enabled_extensions': ('linalg', 'fft'), } @@ -29,6 +31,7 @@ def test_flags(): flags = get_array_api_strict_flags() assert flags == { 'api_version': '2022.12', + 'boolean_indexing': True, 'data_dependent_shapes': False, 'enabled_extensions': ('fft',), } @@ -41,6 +44,7 @@ def test_flags(): flags = get_array_api_strict_flags() assert flags == { 'api_version': '2021.12', + 'boolean_indexing': True, 'data_dependent_shapes': False, 'enabled_extensions': ('linalg',), } @@ -58,12 +62,14 @@ def test_flags(): with pytest.warns(UserWarning): set_array_api_strict_flags( api_version='2021.12', + boolean_indexing=False, data_dependent_shapes=False, enabled_extensions=()) reset_array_api_strict_flags() flags = get_array_api_strict_flags() assert flags == { 'api_version': '2022.12', + 'boolean_indexing': True, 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft'), } @@ -96,6 +102,17 @@ def test_data_dependent_shapes(): pytest.raises(RuntimeError, lambda: unique_inverse(a)) pytest.raises(RuntimeError, lambda: unique_values(a)) pytest.raises(RuntimeError, lambda: nonzero(a)) + a[mask] # No error (boolean indexing is a separate flag) + +def test_boolean_indexing(): + a = asarray([0, 0, 1, 2, 2]) + mask = asarray([True, False, True, False, True]) + + # Should not error + a[mask] + + set_array_api_strict_flags(boolean_indexing=False) + pytest.raises(RuntimeError, lambda: a[mask]) linalg_examples = { diff --git a/docs/api.rst b/docs/api.rst index e703a63..15ce4e9 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -30,6 +30,10 @@ used by array-api-strict initially. They will not change the defaults used by A string representing the version number. +.. envvar:: ARRAY_API_STRICT_BOOLEAN_INDEXING + + "True" or "False" to enable or disable boolean indexing. + .. envvar:: ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES "True" or "False" to enable or disable data dependent shapes. From 161acaa38d29add635a5fb1e52af561c957f4d40 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 25 Apr 2024 16:20:03 -0600 Subject: [PATCH 061/252] Add the inspection APIs --- array_api_strict/__init__.py | 6 ++ array_api_strict/_flags.py | 8 ++ array_api_strict/_info.py | 141 +++++++++++++++++++++++++++++++++++ array_api_strict/_typing.py | 38 ++++++++++ 4 files changed, 193 insertions(+) create mode 100644 array_api_strict/_info.py diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 17cb2c3..82a3cdd 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -260,6 +260,12 @@ __all__ += ["take"] +from ._info import __array_namespace_info__ + +__all__ += [ + "__array_namespace_info__", +] + # linalg is an extension in the array API spec, which is a sub-namespace. Only # a subset of functions in it are imported into the top-level namespace. from . import linalg diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 76dc96e..632c42b 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -178,6 +178,14 @@ def get_array_api_strict_flags(): This function is **not** part of the array API standard. It only exists in array-api-strict. + .. note:: + + The `inspection API + `__ + provides a portable way to access most of this information. However, it + is only present in standard versions starting with 2023.12. The array + API version can be accessed portably using `xp.__array_api_version__`. + Returns ------- dict diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py new file mode 100644 index 0000000..5f8c841 --- /dev/null +++ b/array_api_strict/_info.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +__all__ = [ + "__array_namespace_info__", + "capabilities", + "default_device", + "default_dtypes", + "devices", + "dtypes", +] + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Optional, Union, Tuple, List + from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info + +from ._array_object import CPU_DEVICE +from ._flags import get_array_api_strict_flags, requires_api_version +from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 + +@requires_api_version('2023.12') +def __array_namespace_info__() -> Info: + import array_api_strict._info + return array_api_strict._info + +@requires_api_version('2023.12') +def capabilities() -> Capabilities: + flags = get_array_api_strict_flags() + return {"boolean indexing": flags['boolean_indexing'], + "data-dependent shapes": flags['data_dependent_shapes'], + } + +@requires_api_version('2023.12') +def default_device() -> device: + return CPU_DEVICE + +@requires_api_version('2023.12') +def default_dtypes( + *, + device: Optional[device] = None, +) -> DefaultDataTypes: + return { + "real floating": float64, + "complex floating": complex128, + "integral": int64, + "indexing": int64, + } + +@requires_api_version('2023.12') +def dtypes( + *, + device: Optional[device] = None, + kind: Optional[Union[str, Tuple[str, ...]]] = None, +) -> DataTypes: + if kind is None: + return { + "bool": bool, + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + } + if kind == "unsigned integer": + return { + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "integral": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "real floating": + return { + "float32": float32, + "float64": float64, + } + if kind == "complex floating": + return { + "complex64": complex64, + "complex128": complex128, + } + if kind == "numeric": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") + +@requires_api_version('2023.12') +def devices() -> List[device]: + return [CPU_DEVICE] + +__all__ = [ + "capabilities", + "default_device", + "default_dtypes", + "devices", + "dtypes", +] diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index ce25d4c..eb1b834 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -21,6 +21,8 @@ from typing import ( Any, + ModuleType, + TypedDict, TypeVar, Protocol, ) @@ -39,6 +41,8 @@ def __len__(self, /) -> int: ... Dtype = _DType +Info = ModuleType + if sys.version_info >= (3, 12): from collections.abc import Buffer as SupportsBufferProtocol else: @@ -48,3 +52,37 @@ def __len__(self, /) -> int: ... class SupportsDLPack(Protocol): def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ... + +Capabilities = TypedDict( + "Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool} +) + +DefaultDataTypes = TypedDict( + "DefaultDataTypes", + { + "real floating": Dtype, + "complex floating": Dtype, + "integral": Dtype, + "indexing": Dtype, + }, +) + +DataTypes = TypedDict( + "DataTypes", + { + "bool": Dtype, + "float32": Dtype, + "float64": Dtype, + "complex64": Dtype, + "complex128": Dtype, + "int8": Dtype, + "int16": Dtype, + "int32": Dtype, + "int64": Dtype, + "uint8": Dtype, + "uint16": Dtype, + "uint32": Dtype, + "uint64": Dtype, + }, + total=False, +) From 4d3ff6c3ebe5cea5d15d453cf749ca93a3882540 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 26 Apr 2024 14:34:44 -0600 Subject: [PATCH 062/252] Fix test failures --- array_api_strict/tests/test_flags.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index f1b20cc..0cb670b 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -55,6 +55,7 @@ def test_flags(): flags = get_array_api_strict_flags() assert flags == { 'api_version': '2021.12', + 'boolean_indexing': True, 'data_dependent_shapes': True, 'enabled_extensions': ('linalg',), } @@ -68,6 +69,7 @@ def test_flags(): flags = get_array_api_strict_flags() assert flags == { 'api_version': '2023.12', + 'boolean_indexing': True, 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft'), } @@ -132,6 +134,8 @@ def test_data_dependent_shapes(): pytest.raises(RuntimeError, lambda: unique_inverse(a)) pytest.raises(RuntimeError, lambda: unique_values(a)) pytest.raises(RuntimeError, lambda: nonzero(a)) + pytest.raises(RuntimeError, lambda: repeat(a, repeats)) + repeat(a, 2) # Should never error a[mask] # No error (boolean indexing is a separate flag) def test_boolean_indexing(): @@ -144,8 +148,6 @@ def test_boolean_indexing(): set_array_api_strict_flags(boolean_indexing=False) pytest.raises(RuntimeError, lambda: a[mask]) - pytest.raises(RuntimeError, lambda: repeat(a, repeats)) - repeat(a, 2) # Should never error linalg_examples = { 'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)), From 84d2aa5a94fc48627ba454edb337d8f70912a5c8 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 26 Apr 2024 15:52:30 -0600 Subject: [PATCH 063/252] Always make warnings errors in the tests We might need to remove this if we ever test things that NumPy raises warnings for. --- pytest.ini | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 pytest.ini diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..0c84ee3 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +filterwarnings = error From 05fa0b5f8d6fd11e3be166bb1e75e2e3a55bc95a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 26 Apr 2024 15:56:58 -0600 Subject: [PATCH 064/252] Add tests that the new 2023.12 functions are properly decorated --- array_api_strict/tests/test_flags.py | 38 ++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 0cb670b..65aa26f 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -1,5 +1,7 @@ from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags, reset_array_api_strict_flags) +from .._info import (capabilities, default_device, default_dtypes, devices, + dtypes) from .. import (asarray, unique_all, unique_counts, unique_inverse, unique_values, nonzero, repeat) @@ -237,3 +239,39 @@ def test_fft(func_name): set_array_api_strict_flags(enabled_extensions=('fft',)) func() + +api_version_2023_12_examples = { + '__array_namespace_info__': lambda: xp.__array_namespace_info__(), + # Test these functions directly to ensure they are properly decorated + 'capabilities': capabilities, + 'default_device': default_device, + 'default_dtypes': default_dtypes, + 'devices': devices, + 'dtypes': dtypes, + 'clip': lambda: xp.clip(xp.asarray([1, 2, 3]), 1, 2), + 'copysign': lambda: xp.copysign(xp.asarray([1., 2., 3.]), xp.asarray([-1., -1., -1.])), + 'cumulative_sum': lambda: xp.cumulative_sum(xp.asarray([1, 2, 3])), + 'hypot': lambda: xp.hypot(xp.asarray([3., 4.]), xp.asarray([4., 3.])), + 'maximum': lambda: xp.maximum(xp.asarray([1, 2, 3]), xp.asarray([2, 3, 4])), + 'minimum': lambda: xp.minimum(xp.asarray([1, 2, 3]), xp.asarray([2, 3, 4])), + 'moveaxis': lambda: xp.moveaxis(xp.ones((3, 3)), 0, 1), + 'repeat': lambda: xp.repeat(xp.asarray([1, 2, 3]), 3), + 'searchsorted': lambda: xp.searchsorted(xp.asarray([1, 2, 3]), xp.asarray([0, 1, 2, 3, 4])), + 'signbit': lambda: xp.signbit(xp.asarray([-1., 0., 1.])), + 'tile': lambda: xp.tile(xp.ones((3, 3)), (2, 3)), + 'unstack': lambda: xp.unstack(xp.ones((3, 3)), axis=0), +} + +@pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys()) +def test_api_version_2023_12(func_name): + func = api_version_2023_12_examples[func_name] + + # By default, these functions should error + pytest.raises(RuntimeError, func) + + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2023.12') + func() + + set_array_api_strict_flags(api_version='2022.12') + pytest.raises(RuntimeError, func) From 83331076bfa5824942af6bf9cc27b38903d6dc85 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 26 Apr 2024 16:04:14 -0600 Subject: [PATCH 065/252] Update documentation for 2023.12 support --- array_api_strict/_flags.py | 3 +-- docs/api.rst | 2 ++ docs/index.md | 17 +++++++++++------ 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 632c42b..c0b744e 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -71,10 +71,9 @@ def set_array_api_strict_flags( Note that 2021.12 is supported, but currently gives the same thing as 2022.12 (except that the fft extension will be disabled). - 2023.12 support is preliminary. Some features in 2023.12 may still be + 2023.12 support is experimental. Some features in 2023.12 may still be missing, and it hasn't been fully tested. - - `boolean_indexing`: Whether indexing by a boolean array is supported. Note that although boolean array indexing does result in data-dependent shapes, this flag is independent of the `data_dependent_shapes` flag diff --git a/docs/api.rst b/docs/api.rst index 15ce4e9..ed702dc 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -11,6 +11,8 @@ Array API Strict Flags .. currentmodule:: array_api_strict .. autofunction:: get_array_api_strict_flags + +.. _set_array_api_strict_flags: .. autofunction:: set_array_api_strict_flags .. autofunction:: reset_array_api_strict_flags .. autoclass:: ArrayAPIStrictFlags diff --git a/docs/index.md b/docs/index.md index 6e84efa..fc385d4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -15,9 +15,12 @@ libraries. Consuming library code should use the support the array API. Rather, it is intended to be used in the test suites of consuming libraries to test their array API usage. -array-api-strict currently supports the 2022.12 version of the standard. -2023.12 support is planned and is tracked by [this -issue](https://github.com/data-apis/array-api-strict/issues/25). +array-api-strict currently supports the +[2022.12](https://data-apis.org/array-api/latest/changelog.html#v2022-12) +version of the standard. Experimental +[2023.12](https://data-apis.org/array-api/latest/changelog.html#v2023-12) +support is implemented, [but must be enabled with a +flag](set_array_api_strict_flags). ## Install @@ -179,9 +182,11 @@ issue, but this hasn't necessarily been tested thoroughly. function. array-api-strict currently implements all of these. In the future, [there may be a way to disable them](https://github.com/data-apis/array-api-strict/issues/7). -6. array-api-strict currently only supports the 2022.12 version of the array - API standard. [Support for 2023.12 is - planned](https://github.com/data-apis/array-api-strict/issues/25). +6. array-api-strict currently uses the 2022.12 version of the array API + standard. Support for 2023.12 is implemented but is still experimental and + not fully tested. It can be enabled with + [`array_api_strict.set_array_api_strict_flags(api_version='2023.12')`](set_array_api_strict_flags). + (numpy.array_api)= ## Relationship to `numpy.array_api` From a437da32033e1dc09b040b52e178648c4b472b76 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 26 Apr 2024 23:58:50 -0600 Subject: [PATCH 066/252] Implement 2023.12 behavior for sum() and prod() --- array_api_strict/_statistical_functions.py | 34 ++++++++++++---------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index b35d26f..7a42d25 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -7,7 +7,7 @@ ) from ._array_object import Array from ._dtypes import float32, complex64 -from ._flags import requires_api_version +from ._flags import requires_api_version, get_array_api_strict_flags from ._creation_functions import zeros from ._manipulation_functions import concat @@ -89,14 +89,16 @@ def prod( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in prod") - # Note: sum() and prod() always upcast for dtype=None. `np.prod` does that - # for integers, but not for float32 or complex64, so we need to - # special-case it here + if dtype is None: - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 + # Note: In versions prior to 2023.12, sum() and prod() upcast for all + # dtypes when dtype=None. For 2023.12, the behavior is the same as in + # NumPy (only upcast for integral dtypes). + if get_array_api_strict_flags()['api_version'] < '2023.12': + if x.dtype == float32: + dtype = np.float64 + elif x.dtype == complex64: + dtype = np.complex128 else: dtype = dtype._np_dtype return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims)) @@ -126,14 +128,16 @@ def sum( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sum") - # Note: sum() and prod() always upcast for dtype=None. `np.sum` does that - # for integers, but not for float32 or complex64, so we need to - # special-case it here + if dtype is None: - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 + # Note: In versions prior to 2023.12, sum() and prod() upcast for all + # dtypes when dtype=None. For 2023.12, the behavior is the same as in + # NumPy (only upcast for integral dtypes). + if get_array_api_strict_flags()['api_version'] < '2023.12': + if x.dtype == float32: + dtype = np.float64 + elif x.dtype == complex64: + dtype = np.complex128 else: dtype = dtype._np_dtype return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims)) From 9f954e63265bfb1d655865ad390abe2ebe3ac585 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 29 Apr 2024 14:28:48 -0600 Subject: [PATCH 067/252] Implement 2023.12 behavior for trace --- array_api_strict/linalg.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/array_api_strict/linalg.py b/array_api_strict/linalg.py index 1f548f0..3a0657e 100644 --- a/array_api_strict/linalg.py +++ b/array_api_strict/linalg.py @@ -11,7 +11,7 @@ from ._manipulation_functions import reshape from ._elementwise_functions import conj from ._array_object import Array -from ._flags import requires_extension +from ._flags import requires_extension, get_array_api_strict_flags try: from numpy._core.numeric import normalize_axis_tuple @@ -377,10 +377,11 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr # Note: trace() works the same as sum() and prod() (see # _statistical_functions.py) if dtype is None: - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 + if get_array_api_strict_flags()['api_version'] < '2023.12': + if x.dtype == float32: + dtype = np.float64 + elif x.dtype == complex64: + dtype = np.complex128 else: dtype = dtype._np_dtype # Note: trace always operates on the last two axes, whereas np.trace From 8572df37a1e0bdb543b505900be54e728d0bf79c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 29 Apr 2024 14:29:06 -0600 Subject: [PATCH 068/252] Add a test for sum/trace/prod 2023.12 upcasting behavior --- .../tests/test_statistical_functions.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 array_api_strict/tests/test_statistical_functions.py diff --git a/array_api_strict/tests/test_statistical_functions.py b/array_api_strict/tests/test_statistical_functions.py new file mode 100644 index 0000000..fcf8f7f --- /dev/null +++ b/array_api_strict/tests/test_statistical_functions.py @@ -0,0 +1,27 @@ +import pytest + +import array_api_strict as xp + +@pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace']) +def test_sum_prod_trace_2023_12(func_name): + # sum, prod, and trace were changed in 2023.12 to not upcast floating-point dtypes + # with dtype=None + if func_name == 'trace': + func = getattr(xp.linalg, func_name) + else: + func = getattr(xp, func_name) + + a_real = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.float32) + a_complex = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.complex64) + a_int = xp.asarray([[1, 2], [3, 4]], dtype=xp.int32) + + assert func(a_real).dtype == xp.float64 + assert func(a_complex).dtype == xp.complex128 + assert func(a_int).dtype == xp.int64 + + with pytest.warns(UserWarning): + xp.set_array_api_strict_flags(api_version='2023.12') + + assert func(a_real).dtype == xp.float32 + assert func(a_complex).dtype == xp.complex64 + assert func(a_int).dtype == xp.int64 From 47894ff54bc9b0cd40018e105aa2ef99ff3dd19c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 29 Apr 2024 15:14:37 -0600 Subject: [PATCH 069/252] Add 2023.12 axis restrictions to vecdot() and cross() --- array_api_strict/_linear_algebra_functions.py | 15 +- array_api_strict/linalg.py | 11 ++ array_api_strict/tests/test_linalg.py | 133 ++++++++++++++++++ .../tests/test_statistical_functions.py | 4 +- 4 files changed, 161 insertions(+), 2 deletions(-) create mode 100644 array_api_strict/tests/test_linalg.py diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py index 1ff08d4..6a1a921 100644 --- a/array_api_strict/_linear_algebra_functions.py +++ b/array_api_strict/_linear_algebra_functions.py @@ -8,8 +8,8 @@ from __future__ import annotations from ._dtypes import _numeric_dtypes - from ._array_object import Array +from ._flags import get_array_api_strict_flags from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -54,6 +54,19 @@ def matrix_transpose(x: Array, /) -> Array: def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in vecdot') + + if get_array_api_strict_flags()['api_version'] >= '2023.12': + if axis >= 0: + raise ValueError("axis must be negative in vecdot") + elif axis < min(-1, -x1.ndim, -x2.ndim): + raise ValueError("axis is out of bounds for x1 and x2") + + # In versions if the standard prior to 2023.12, vecdot applied axis after + # broadcasting. This is different from applying it before broadcasting + # when axis is nonnegative. The below code keeps this behavior for + # 2022.12, primarily for backwards compatibility. Note that the behavior + # is unambiguous when axis is negative, so the below code should work + # correctly in that case regardless of which version is used. ndim = max(x1.ndim, x2.ndim) x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) diff --git a/array_api_strict/linalg.py b/array_api_strict/linalg.py index 3a0657e..bd11aa4 100644 --- a/array_api_strict/linalg.py +++ b/array_api_strict/linalg.py @@ -80,6 +80,17 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: # Note: this is different from np.cross(), which allows dimension 2 if x1.shape[axis] != 3: raise ValueError('cross() dimension must equal 3') + + if get_array_api_strict_flags()['api_version'] >= '2023.12': + if axis >= 0: + raise ValueError("axis must be negative in cross") + elif axis < min(-1, -x1.ndim, -x2.ndim): + raise ValueError("axis is out of bounds for x1 and x2") + + # Prior to 2023.12, there was ambiguity in the standard about whether + # positive axis applied before or after broadcasting. NumPy applies + # the axis before broadcasting. Since that behavior is what has always + # been implemented here, we keep it for backwards compatibility. return Array._new(np.cross(x1._array, x2._array, axis=axis)) @requires_extension('linalg') diff --git a/array_api_strict/tests/test_linalg.py b/array_api_strict/tests/test_linalg.py new file mode 100644 index 0000000..5e6cda2 --- /dev/null +++ b/array_api_strict/tests/test_linalg.py @@ -0,0 +1,133 @@ +import pytest + +from .._flags import set_array_api_strict_flags + +import array_api_strict as xp + +# TODO: Maybe all of these exceptions should be IndexError? + +# Technically this is linear_algebra, not linalg, but it's simpler to keep +# both of these tests together +def test_vecdot_2023_12(): + # Test the axis < 0 restriction for 2023.12, and also the 2022.12 axis >= + # 0 behavior (which is primarily kept for backwards compatibility). + + a = xp.ones((2, 3, 4, 5)) + b = xp.ones(( 3, 4, 1)) + + # 2022.12 behavior, which is to apply axis >= 0 after broadcasting + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0)) + assert xp.linalg.vecdot(a, b, axis=1).shape == (2, 4, 5) + assert xp.linalg.vecdot(a, b, axis=2).shape == (2, 3, 5) + # This is disallowed because the arrays must have the same values before + # broadcasting + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-1)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-4)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=3)) + + # Out-of-bounds axes even after broadcasting + pytest.raises(IndexError, lambda: xp.linalg.vecdot(a, b, axis=4)) + pytest.raises(IndexError, lambda: xp.linalg.vecdot(a, b, axis=-5)) + + # negative axis behavior is unambiguous when it's within the bounds of + # both arrays before broadcasting + assert xp.linalg.vecdot(a, b, axis=-2).shape == (2, 3, 5) + assert xp.linalg.vecdot(a, b, axis=-3).shape == (2, 4, 5) + + # 2023.12 behavior, which is to only allow axis < 0 and axis >= + # min(x1.ndim, x2.ndim), which is unambiguous + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2023.12') + + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=1)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=2)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=3)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-1)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-4)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=4)) + pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-5)) + + assert xp.linalg.vecdot(a, b, axis=-2).shape == (2, 3, 5) + assert xp.linalg.vecdot(a, b, axis=-3).shape == (2, 4, 5) + +@pytest.mark.parametrize('api_version', ['2021.12', '2022.12', '2023.12']) +def test_cross(api_version): + # This test tests everything that should be the same across all supported + # API versions. + + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = xp.ones((2, 4, 5, 3)) + b = xp.ones(( 4, 1, 3)) + assert xp.linalg.cross(a, b, axis=-1).shape == (2, 4, 5, 3) + + a = xp.ones((2, 4, 3, 5)) + b = xp.ones(( 4, 3, 1)) + assert xp.linalg.cross(a, b, axis=-2).shape == (2, 4, 3, 5) + + # This is disallowed because the axes must equal 3 before broadcasting + a = xp.ones((3, 2, 3, 5)) + b = xp.ones(( 2, 1, 1)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-1)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-2)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-3)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-4)) + + # Out-of-bounds axes even after broadcasting + pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=4)) + pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=-5)) + +@pytest.mark.parametrize('api_version', ['2021.12', '2022.12']) +def test_cross_2022_12(api_version): + # Test the 2022.12 axis >= 0 behavior, which is primarily kept for + # backwards compatibility. Note that unlike vecdot, array_api_strict + # cross() never implemented the "after broadcasting" axis behavior, but + # just reused NumPy cross(), which applies axes before broadcasting. + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = xp.ones((3, 2, 4, 5)) + b = xp.ones((3, 2, 4, 1)) + assert xp.linalg.cross(a, b, axis=0).shape == (3, 2, 4, 5) + + # ambiguous case + a = xp.ones(( 3, 4, 5)) + b = xp.ones((3, 2, 4, 1)) + assert xp.linalg.cross(a, b, axis=0).shape == (3, 2, 4, 5) + +def test_cross_2023_12(): + # 2023.12 behavior, which is to only allow axis < 0 and axis >= + # min(x1.ndim, x2.ndim), which is unambiguous + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2023.12') + + a = xp.ones((3, 2, 4, 5)) + b = xp.ones((3, 2, 4, 1)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=0)) + + a = xp.ones(( 3, 4, 5)) + b = xp.ones((3, 2, 4, 1)) + pytest.raises(ValueError, lambda: xp. linalg.cross(a, b, axis=0)) + + a = xp.ones((2, 4, 5, 3)) + b = xp.ones(( 4, 1, 3)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=0)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=1)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=2)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=3)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-2)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-3)) + pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-4)) + + pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=4)) + pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=-5)) + + assert xp.linalg.cross(a, b, axis=-1).shape == (2, 4, 5, 3) diff --git a/array_api_strict/tests/test_statistical_functions.py b/array_api_strict/tests/test_statistical_functions.py index fcf8f7f..61e848c 100644 --- a/array_api_strict/tests/test_statistical_functions.py +++ b/array_api_strict/tests/test_statistical_functions.py @@ -1,5 +1,7 @@ import pytest +from .._flags import set_array_api_strict_flags + import array_api_strict as xp @pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace']) @@ -20,7 +22,7 @@ def test_sum_prod_trace_2023_12(func_name): assert func(a_int).dtype == xp.int64 with pytest.warns(UserWarning): - xp.set_array_api_strict_flags(api_version='2023.12') + set_array_api_strict_flags(api_version='2023.12') assert func(a_real).dtype == xp.float32 assert func(a_complex).dtype == xp.complex64 From 6b431946e4135b3815d7ff050cac8c26c0ce6d5d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 29 Apr 2024 15:31:32 -0600 Subject: [PATCH 070/252] Add device flag to astype in 2023.12 Also clean up imports in test_data_type_functions.py --- array_api_strict/_data_type_functions.py | 20 +++++-- .../tests/test_data_type_functions.py | 54 ++++++++++++++----- 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 41f70c5..7ae6244 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ._array_object import Array +from ._array_object import Array, CPU_DEVICE from ._dtypes import ( _DType, _all_dtypes, @@ -13,19 +13,31 @@ _numeric_dtypes, _result_type, ) +from ._flags import get_array_api_strict_flags from dataclasses import dataclass from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import List, Tuple, Union - from ._typing import Dtype + from typing import List, Tuple, Union, Optional + from ._typing import Dtype, Device import numpy as np +# Use to emulate the asarray(device) argument not existing in 2022.12 +_default = object() # Note: astype is a function, not an array method as in NumPy. -def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array: +def astype( + x: Array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = _default +) -> Array: + if device is not _default: + if get_array_api_strict_flags()['api_version'] >= '2023.12': + if device not in [CPU_DEVICE, None]: + raise ValueError(f"Unsupported device {device!r}") + else: + raise TypeError("The device argument to astype requires the 2023.12 version of the array API") + if not copy and dtype == x.dtype: return x return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy)) diff --git a/array_api_strict/tests/test_data_type_functions.py b/array_api_strict/tests/test_data_type_functions.py index 60a7f29..40cab55 100644 --- a/array_api_strict/tests/test_data_type_functions.py +++ b/array_api_strict/tests/test_data_type_functions.py @@ -3,38 +3,68 @@ import pytest from numpy.testing import assert_raises -import array_api_strict as xp import numpy as np +from .._creation_functions import asarray +from .._data_type_functions import astype, can_cast, isdtype +from .._dtypes import ( + bool, int8, int16, uint8, float64, +) +from .._flags import set_array_api_strict_flags + + @pytest.mark.parametrize( "from_, to, expected", [ - (xp.int8, xp.int16, True), - (xp.int16, xp.int8, False), - (xp.bool, xp.int8, False), - (xp.asarray(0, dtype=xp.uint8), xp.int8, False), + (int8, int16, True), + (int16, int8, False), + (bool, int8, False), + (asarray(0, dtype=uint8), int8, False), ], ) def test_can_cast(from_, to, expected): """ can_cast() returns correct result """ - assert xp.can_cast(from_, to) == expected + assert can_cast(from_, to) == expected def test_isdtype_strictness(): - assert_raises(TypeError, lambda: xp.isdtype(xp.float64, 64)) - assert_raises(ValueError, lambda: xp.isdtype(xp.float64, 'f8')) + assert_raises(TypeError, lambda: isdtype(float64, 64)) + assert_raises(ValueError, lambda: isdtype(float64, 'f8')) - assert_raises(TypeError, lambda: xp.isdtype(xp.float64, (('integral',),))) + assert_raises(TypeError, lambda: isdtype(float64, (('integral',),))) with assert_raises(TypeError), warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - xp.isdtype(xp.float64, np.object_) + isdtype(float64, np.object_) assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) - assert_raises(TypeError, lambda: xp.isdtype(xp.float64, None)) + assert_raises(TypeError, lambda: isdtype(float64, None)) with assert_raises(TypeError), warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - xp.isdtype(xp.float64, np.float64) + isdtype(float64, np.float64) assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) + + +@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) +def astype_device(api_version): + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = asarray([1, 2, 3], dtype=int8) + # Never an error + astype(a, int16) + + # Always an error + astype(a, int16, device="cpu") + + if api_version >= '2023.12': + astype(a, int8, device=None) + astype(a, int8, device=a.device) + else: + pytest.raises(TypeError, lambda: astype(a, int8, device=None)) + pytest.raises(TypeError, lambda: astype(a, int8, device=a.device)) From 3fde5ddfbe7b0e60d9ab2732676521e93a0e1a07 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 00:07:51 -0600 Subject: [PATCH 071/252] Factor out device checks into a helper function --- array_api_strict/_creation_functions.py | 75 +++++++++++++----------- array_api_strict/_data_type_functions.py | 6 +- 2 files changed, 43 insertions(+), 38 deletions(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index ad7ec82..dd3e74f 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -28,6 +28,14 @@ def _supports_buffer_protocol(obj): return False return True +def _check_device(device): + # _array_object imports in this file are inside the functions to avoid + # circular imports + from ._array_object import CPU_DEVICE + + if device not in [CPU_DEVICE, None]: + raise ValueError(f"Unsupported device {device!r}") + def asarray( obj: Union[ Array, @@ -48,16 +56,13 @@ def asarray( See its docstring for more information. """ - # _array_object imports in this file are inside the functions to avoid - # circular imports - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) _np_dtype = None if dtype is not None: _np_dtype = dtype._np_dtype - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) if np.__version__[0] < '2': if copy is False: @@ -106,11 +111,11 @@ def arange( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype)) @@ -127,11 +132,11 @@ def empty( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.empty(shape, dtype=dtype)) @@ -145,11 +150,11 @@ def empty_like( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.empty_like(x._array, dtype=dtype)) @@ -197,11 +202,11 @@ def full( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if isinstance(fill_value, Array) and fill_value.ndim == 0: fill_value = fill_value._array if dtype is not None: @@ -227,11 +232,11 @@ def full_like( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype res = np.full_like(x._array, fill_value, dtype=dtype) @@ -257,11 +262,11 @@ def linspace( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)) @@ -298,11 +303,11 @@ def ones( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.ones(shape, dtype=dtype)) @@ -316,11 +321,11 @@ def ones_like( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.ones_like(x._array, dtype=dtype)) @@ -365,11 +370,11 @@ def zeros( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.zeros(shape, dtype=dtype)) @@ -383,11 +388,11 @@ def zeros_like( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.zeros_like(x._array, dtype=dtype)) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 7ae6244..e43125a 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations -from ._array_object import Array, CPU_DEVICE +from ._array_object import Array +from ._creation_functions import _check_device from ._dtypes import ( _DType, _all_dtypes, @@ -33,8 +34,7 @@ def astype( ) -> Array: if device is not _default: if get_array_api_strict_flags()['api_version'] >= '2023.12': - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) else: raise TypeError("The device argument to astype requires the 2023.12 version of the array API") From 1ac528821ac90d741cf7f0245a383e877366df5c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 00:08:03 -0600 Subject: [PATCH 072/252] Add 2023.12 device and copy keywords to from_dlpack The copy keyword just raises NotImplementedError for now. --- array_api_strict/_creation_functions.py | 28 ++++++++++++++++++++---- array_api_strict/_data_type_functions.py | 2 +- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index dd3e74f..b24af98 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -12,6 +12,7 @@ SupportsBufferProtocol, ) from ._dtypes import _DType, _all_dtypes +from ._flags import get_array_api_strict_flags import numpy as np @@ -174,19 +175,38 @@ def eye( See its docstring for more information. """ - from ._array_object import Array, CPU_DEVICE + from ._array_object import Array _check_valid_dtype(dtype) - if device not in [CPU_DEVICE, None]: - raise ValueError(f"Unsupported device {device!r}") + _check_device(device) + if dtype is not None: dtype = dtype._np_dtype return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype)) -def from_dlpack(x: object, /) -> Array: +_default = object() + +def from_dlpack( + x: object, + /, + *, + device: Optional[Device] = _default, + copy: Optional[bool] = _default, +) -> Array: from ._array_object import Array + if get_array_api_strict_flags()['api_version'] < '2023.12': + if device is not _default: + raise ValueError("The device argument to from_dlpack requires at least version 2023.12 of the array API") + if copy is not _default: + raise ValueError("The copy argument to from_dlpack requires at least version 2023.12 of the array API") + + if device is not _default: + _check_device(device) + if copy not in [_default, None]: + raise NotImplementedError("The copy argument to from_dlpack is not yet implemented") + return Array._new(np.from_dlpack(x)) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index e43125a..3405710 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -36,7 +36,7 @@ def astype( if get_array_api_strict_flags()['api_version'] >= '2023.12': _check_device(device) else: - raise TypeError("The device argument to astype requires the 2023.12 version of the array API") + raise TypeError("The device argument to astype requires at least version 2023.12 of the array API") if not copy and dtype == x.dtype: return x From dc4684b4be7e70add3c197dff56deeef299d43bd Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 13:56:00 -0600 Subject: [PATCH 073/252] Update the signature of __dlpack__ for 2023.12 The new arguments are not actually supported yet, and probably won't be until upstream NumPy does. --- array_api_strict/_array_object.py | 28 ++++++++++++++++++++++++- array_api_strict/_creation_functions.py | 1 + 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 8849ce3..26c4330 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -51,6 +51,8 @@ def __repr__(self): CPU_DEVICE = _cpu_device() +_default = object() + class Array: """ n-d array object for the array API namespace. @@ -525,10 +527,34 @@ def __complex__(self: Array, /) -> complex: res = self._array.__complex__() return res - def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule: + def __dlpack__( + self: Array, + /, + *, + stream: Optional[Union[int, Any]] = None, + max_version: Optional[tuple[int, int]] = _default, + dl_device: Optional[tuple[IntEnum, int]] = _default, + copy: Optional[bool] = _default, + ) -> PyCapsule: """ Performs the operation __dlpack__. """ + if get_array_api_strict_flags()['api_version'] < '2023.12': + if max_version is not _default: + raise ValueError("The max_version argument to __dlpack__ requires at least version 2023.12 of the array API") + if dl_device is not _default: + raise ValueError("The device argument to __dlpack__ requires at least version 2023.12 of the array API") + if copy is not _default: + raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API") + + # Going to wait for upstream numpy support + if max_version not in [_default, None]: + raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") + if dl_device not in [_default, None]: + raise NotImplementedError("The device argument to __dlpack__ is not yet implemented") + if copy not in [_default, None]: + raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") + return self._array.__dlpack__(stream=stream) def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index b24af98..0e85cdc 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -202,6 +202,7 @@ def from_dlpack( if copy is not _default: raise ValueError("The copy argument to from_dlpack requires at least version 2023.12 of the array API") + # Going to wait for upstream numpy support if device is not _default: _check_device(device) if copy not in [_default, None]: From 647a5f004053a42203707297d609793a0ef25210 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 15:22:45 -0600 Subject: [PATCH 074/252] Add tests for from_dlpack and __dlpack__ 2023.12 behavior --- array_api_strict/tests/test_array_object.py | 32 +++++++++++++++++++ .../tests/test_creation_functions.py | 26 ++++++++++++++- 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index a66637f..f0efdfa 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -23,6 +23,8 @@ uint64, bool as bool_, ) +from .._flags import set_array_api_strict_flags + import array_api_strict def test_validate_index(): @@ -420,3 +422,33 @@ def test_array_namespace(): pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) + + +@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) +def dlpack_2023_12(api_version): + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = asarray([1, 2, 3], dtype=int8) + # Never an error + a.__dlpack__() + + + exception = NotImplementedError if api_version >= '2023.12' else ValueError + pytest.raises(exception, lambda: + a.__dlpack__(dl_device=CPU_DEVICE)) + pytest.raises(exception, lambda: + a.__dlpack__(dl_device=None)) + pytest.raises(exception, lambda: + a.__dlpack__(max_version=(1, 0))) + pytest.raises(exception, lambda: + a.__dlpack__(max_version=None)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=False)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=True)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=None)) diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 78d4c80..819afad 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -3,6 +3,8 @@ from numpy.testing import assert_raises import numpy as np +import pytest + from .. import all from .._creation_functions import ( asarray, @@ -10,6 +12,7 @@ empty, empty_like, eye, + from_dlpack, full, full_like, linspace, @@ -21,7 +24,7 @@ ) from .._dtypes import float32, float64 from .._array_object import Array, CPU_DEVICE - +from .._flags import set_array_api_strict_flags def test_asarray_errors(): # Test various protections against incorrect usage @@ -188,3 +191,24 @@ def test_meshgrid_dtype_errors(): meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float32)) assert_raises(ValueError, lambda: meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float64))) + + +@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) +def from_dlpack_2023_12(api_version): + if api_version != '2022.12': + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + else: + set_array_api_strict_flags(api_version=api_version) + + a = asarray([1., 2., 3.], dtype=float64) + # Never an error + capsule = a.__dlpack__() + from_dlpack(capsule) + + exception = NotImplementedError if api_version >= '2023.12' else ValueError + pytest.raises(exception, lambda: from_dlpack(capsule, device=CPU_DEVICE)) + pytest.raises(exception, lambda: from_dlpack(capsule, device=None)) + pytest.raises(exception, lambda: from_dlpack(capsule, copy=False)) + pytest.raises(exception, lambda: from_dlpack(capsule, copy=True)) + pytest.raises(exception, lambda: from_dlpack(capsule, copy=None)) From 306de9bd5f2810636f0c7e1f83a119194b25f47c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 15:29:13 -0600 Subject: [PATCH 075/252] Add 2023.12 testing to the CI --- .github/workflows/array-api-tests.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index bfb7dcf..ce246e4 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -12,6 +12,7 @@ jobs: matrix: python-version: ['3.9', '3.10', '3.11', '3.12'] numpy-version: ['1.26', 'dev'] + api_version: ['2022.12', '2023.12'] exclude: - python-version: '3.8' numpy-version: 'dev' @@ -49,5 +50,12 @@ jobs: # tests fail in numpy 1.26 on bad scalar type promotion behavior) NPY_PROMOTION_STATE: weak run: | + export ARRAY_API_STRICT_API_VERSION=${{ matrix.api_version }} + + # Only signature tests work for now for 2023.12 + if [[ "${{ matrix.api_version }}" == "2023.12" ]]; then + PYTEST_ARGS="${PYTEST_ARGS} -k signature + fi + cd ${GITHUB_WORKSPACE}/array-api-tests pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS} From 44bbdb214fa555c8f8877b720d75ebc07d6c0afc Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 15:31:31 -0600 Subject: [PATCH 076/252] Better error message --- array_api_strict/_array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 26c4330..0fff27a 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -439,7 +439,7 @@ def _validate_index(self, key): "Array API when the array is the sole index." ) if not get_array_api_strict_flags()['boolean_indexing']: - raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the boolean_indexing flag has been disabled for array-api-strict") + raise RuntimeError("The boolean_indexing flag has been disabled for array-api-strict") elif i.dtype in _integer_dtypes and i.ndim != 0: raise IndexError( From e5225ed7f54de3c0e82cc1a63c66f6093c5d204f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 15:34:13 -0600 Subject: [PATCH 077/252] Parameterize the API version in a loop instead of in the matrix --- .github/workflows/array-api-tests.yml | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index ce246e4..b37ec04 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -4,6 +4,7 @@ on: [push, pull_request] env: PYTEST_ARGS: "-v -rxXfE --ci --hypothesis-disable-deadline --max-examples 200" + API_VERSIONS: "2022.12 2023.12" jobs: array-api-tests: @@ -12,7 +13,6 @@ jobs: matrix: python-version: ['3.9', '3.10', '3.11', '3.12'] numpy-version: ['1.26', 'dev'] - api_version: ['2022.12', '2023.12'] exclude: - python-version: '3.8' numpy-version: 'dev' @@ -50,12 +50,13 @@ jobs: # tests fail in numpy 1.26 on bad scalar type promotion behavior) NPY_PROMOTION_STATE: weak run: | - export ARRAY_API_STRICT_API_VERSION=${{ matrix.api_version }} + # Parameterizing this in the CI matrix is wasteful. Just do a loop here. + for ARRAY_API_STRICT_API_VERSION in ${API_VERSIONS}; do + # Only signature tests work for now for 2023.12 + if [[ "$ARRAY_API_STRICT_API_VERSION" == "2023.12" ]]; then + PYTEST_ARGS="${PYTEST_ARGS} -k signature + fi - # Only signature tests work for now for 2023.12 - if [[ "${{ matrix.api_version }}" == "2023.12" ]]; then - PYTEST_ARGS="${PYTEST_ARGS} -k signature - fi - - cd ${GITHUB_WORKSPACE}/array-api-tests - pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS} + cd ${GITHUB_WORKSPACE}/array-api-tests + pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS} + done From 3e0be7df7eff77d4a443b0b7d265d7c0e34ee4e5 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 3 May 2024 16:05:15 -0600 Subject: [PATCH 078/252] Ensure a.mT works even if the linalg extension is disabled --- array_api_strict/_array_object.py | 2 +- array_api_strict/tests/test_flags.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 0fff27a..18dd219 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -1159,7 +1159,7 @@ def device(self) -> Device: # Note: mT is new in array API spec (see matrix_transpose) @property def mT(self) -> Array: - from .linalg import matrix_transpose + from ._linear_algebra_functions import matrix_transpose return matrix_transpose(self) @property diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 65aa26f..b68e7aa 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -184,9 +184,10 @@ def test_boolean_indexing(): 'matrix_transpose': lambda: xp.matrix_transpose(xp.eye(3)), 'tensordot': lambda: xp.tensordot(xp.eye(3), xp.eye(3)), 'vecdot': lambda: xp.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'mT': lambda: xp.eye(3).mT, } -assert set(linalg_main_namespace_examples) == set(xp.__all__) & set(xp.linalg.__all__) +assert set(linalg_main_namespace_examples) == (set(xp.__all__) & set(xp.linalg.__all__)) | {"mT"} @pytest.mark.parametrize('func_name', linalg_examples.keys()) def test_linalg(func_name): From 338ebfefce1ff395391154b33851359dc32565c0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 1 May 2024 16:41:36 -0600 Subject: [PATCH 079/252] Don't allow environment variables to be set during test runs --- array_api_strict/_flags.py | 7 +++++++ array_api_strict/tests/conftest.py | 9 ++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index c0b744e..866e4f5 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -293,6 +293,13 @@ def __exit__(self, exc_type, exc_value, traceback): # Private functions +ENVIRONMENT_VARIABLES = [ + "ARRAY_API_STRICT_API_VERSION", + "ARRAY_API_STRICT_BOOLEAN_INDEXING", + "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES", + "ARRAY_API_STRICT_ENABLED_EXTENSIONS", +] + def set_flags_from_environment(): if "ARRAY_API_STRICT_API_VERSION" in os.environ: set_array_api_strict_flags( diff --git a/array_api_strict/tests/conftest.py b/array_api_strict/tests/conftest.py index 5000d5d..322675c 100644 --- a/array_api_strict/tests/conftest.py +++ b/array_api_strict/tests/conftest.py @@ -1,7 +1,14 @@ -from .._flags import reset_array_api_strict_flags +import os + +from .._flags import reset_array_api_strict_flags, ENVIRONMENT_VARIABLES import pytest +def pytest_sessionstart(session): + for env_var in ENVIRONMENT_VARIABLES: + if env_var in os.environ: + pytest.exit(f"ERROR: {env_var} is set. array-api-strict environment variables must not be set when the tests are run.") + @pytest.fixture(autouse=True) def reset_flags(): reset_array_api_strict_flags() From 6a466f45cbe30fdec4049dee4bfd6e352c90c4a0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 2 May 2024 21:55:58 -0600 Subject: [PATCH 080/252] Use a more robust way to fail the tests if an env var is set --- array_api_strict/tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/tests/conftest.py b/array_api_strict/tests/conftest.py index 322675c..1a9d507 100644 --- a/array_api_strict/tests/conftest.py +++ b/array_api_strict/tests/conftest.py @@ -4,7 +4,7 @@ import pytest -def pytest_sessionstart(session): +def pytest_configure(config): for env_var in ENVIRONMENT_VARIABLES: if env_var in os.environ: pytest.exit(f"ERROR: {env_var} is set. array-api-strict environment variables must not be set when the tests are run.") From dd01b12c75e044ec3b0d3f1a50eb22885670f6ea Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 3 May 2024 15:59:15 -0600 Subject: [PATCH 081/252] Fix setting ARRAY_API_STRICT_ENABLED_EXTENSIONS='' --- array_api_strict/_flags.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 866e4f5..b02b869 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -317,9 +317,10 @@ def set_flags_from_environment(): ) if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ: - set_array_api_strict_flags( - enabled_extensions=os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") - ) + enabled_extensions = os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") + if enabled_extensions == [""]: + enabled_extensions = [] + set_array_api_strict_flags(enabled_extensions=enabled_extensions) set_flags_from_environment() From 43b9088ce2d7f33d50486e03b45fd99ef6166cd6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 3 May 2024 16:06:09 -0600 Subject: [PATCH 082/252] Make extensions give AttributeError when they are disabled This is how the test suite and presumably some other codes detect if extensions are enabled or not. This also dynamically updates __all__ whenever extensions are enabled or disabled. --- array_api_strict/__init__.py | 32 ++-- array_api_strict/{fft.py => _fft.py} | 0 array_api_strict/_flags.py | 7 + array_api_strict/{linalg.py => _linalg.py} | 0 array_api_strict/tests/test_flags.py | 168 ++++++++++++++++----- 5 files changed, 160 insertions(+), 47 deletions(-) rename array_api_strict/{fft.py => _fft.py} (100%) rename array_api_strict/{linalg.py => _linalg.py} (100%) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 82a3cdd..8dfa09f 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -16,13 +16,15 @@ """ +__all__ = [] + # Warning: __array_api_version__ could change globally with # set_array_api_strict_flags(). This should always be accessed as an # attribute, like xp.__array_api_version__, or using # array_api_strict.get_array_api_strict_flags()['api_version']. from ._flags import API_VERSION as __array_api_version__ -__all__ = ["__array_api_version__"] +__all__ += ["__array_api_version__"] from ._constants import e, inf, nan, pi, newaxis @@ -266,19 +268,10 @@ "__array_namespace_info__", ] -# linalg is an extension in the array API spec, which is a sub-namespace. Only -# a subset of functions in it are imported into the top-level namespace. -from . import linalg - -__all__ += ["linalg"] - from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot __all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"] -from . import fft -__all__ += ["fft"] - from ._manipulation_functions import ( concat, expand_dims, @@ -330,3 +323,22 @@ from . import _version __version__ = _version.get_versions()['version'] del _version + + +# Extensions can be enabled or disabled dynamically. In order to make +# "array_api_strict.linalg" give an AttributeError when it is disabled, we +# use __getattr__. Note that linalg and fft are dynamically added and removed +# from __all__ in set_array_api_strict_flags. + +def __getattr__(name): + if name in ['linalg', 'fft']: + if name in get_array_api_strict_flags()['enabled_extensions']: + if name == 'linalg': + from . import _linalg + return _linalg + elif name == 'fft': + from . import _fft + return _fft + else: + raise AttributeError(f"The {name!r} extension has been disabled for array_api_strict") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/array_api_strict/fft.py b/array_api_strict/_fft.py similarity index 100% rename from array_api_strict/fft.py rename to array_api_strict/_fft.py diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index b02b869..221d0d3 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -161,6 +161,10 @@ def set_array_api_strict_flags( else: ENABLED_EXTENSIONS = tuple([ext for ext in ENABLED_EXTENSIONS if extension_versions[ext] <= API_VERSION]) + array_api_strict.__all__[:] = sorted(set(ENABLED_EXTENSIONS) | + set(array_api_strict.__all__) - + set(default_extensions)) + # We have to do this separately or it won't get added as the docstring set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( supported_versions=supported_versions, @@ -321,6 +325,9 @@ def set_flags_from_environment(): if enabled_extensions == [""]: enabled_extensions = [] set_array_api_strict_flags(enabled_extensions=enabled_extensions) + else: + # Needed at first import to add linalg and fft to __all__ + set_array_api_strict_flags(enabled_extensions=default_extensions) set_flags_from_environment() diff --git a/array_api_strict/linalg.py b/array_api_strict/_linalg.py similarity index 100% rename from array_api_strict/linalg.py rename to array_api_strict/_linalg.py diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index b68e7aa..38b1a3b 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -1,7 +1,15 @@ +import sys +import subprocess + from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags, reset_array_api_strict_flags) from .._info import (capabilities, default_device, default_dtypes, devices, dtypes) +from .._fft import (fft, ifft, fftn, ifftn, rfft, irfft, rfftn, irfftn, hfft, + ihfft, fftfreq, rfftfreq, fftshift, ifftshift) +from .._linalg import (cholesky, cross, det, diagonal, eigh, eigvalsh, inv, + matmul, matrix_norm, matrix_power, matrix_rank, matrix_transpose, outer, pinv, + qr, slogdet, solve, svd, svdvals, tensordot, trace, vecdot, vector_norm) from .. import (asarray, unique_all, unique_counts, unique_inverse, unique_values, nonzero, repeat) @@ -152,29 +160,29 @@ def test_boolean_indexing(): pytest.raises(RuntimeError, lambda: a[mask]) linalg_examples = { - 'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)), - 'cross': lambda: xp.linalg.cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])), - 'det': lambda: xp.linalg.det(xp.eye(3)), - 'diagonal': lambda: xp.linalg.diagonal(xp.eye(3)), - 'eigh': lambda: xp.linalg.eigh(xp.eye(3)), - 'eigvalsh': lambda: xp.linalg.eigvalsh(xp.eye(3)), - 'inv': lambda: xp.linalg.inv(xp.eye(3)), - 'matmul': lambda: xp.linalg.matmul(xp.eye(3), xp.eye(3)), - 'matrix_norm': lambda: xp.linalg.matrix_norm(xp.eye(3)), - 'matrix_power': lambda: xp.linalg.matrix_power(xp.eye(3), 2), - 'matrix_rank': lambda: xp.linalg.matrix_rank(xp.eye(3)), - 'matrix_transpose': lambda: xp.linalg.matrix_transpose(xp.eye(3)), - 'outer': lambda: xp.linalg.outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), - 'pinv': lambda: xp.linalg.pinv(xp.eye(3)), - 'qr': lambda: xp.linalg.qr(xp.eye(3)), - 'slogdet': lambda: xp.linalg.slogdet(xp.eye(3)), - 'solve': lambda: xp.linalg.solve(xp.eye(3), xp.eye(3)), - 'svd': lambda: xp.linalg.svd(xp.eye(3)), - 'svdvals': lambda: xp.linalg.svdvals(xp.eye(3)), - 'tensordot': lambda: xp.linalg.tensordot(xp.eye(3), xp.eye(3)), - 'trace': lambda: xp.linalg.trace(xp.eye(3)), - 'vecdot': lambda: xp.linalg.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), - 'vector_norm': lambda: xp.linalg.vector_norm(xp.asarray([1., 2., 3.])), + 'cholesky': lambda: cholesky(xp.eye(3)), + 'cross': lambda: cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])), + 'det': lambda: det(xp.eye(3)), + 'diagonal': lambda: diagonal(xp.eye(3)), + 'eigh': lambda: eigh(xp.eye(3)), + 'eigvalsh': lambda: eigvalsh(xp.eye(3)), + 'inv': lambda: inv(xp.eye(3)), + 'matmul': lambda: matmul(xp.eye(3), xp.eye(3)), + 'matrix_norm': lambda: matrix_norm(xp.eye(3)), + 'matrix_power': lambda: matrix_power(xp.eye(3), 2), + 'matrix_rank': lambda: matrix_rank(xp.eye(3)), + 'matrix_transpose': lambda: matrix_transpose(xp.eye(3)), + 'outer': lambda: outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'pinv': lambda: pinv(xp.eye(3)), + 'qr': lambda: qr(xp.eye(3)), + 'slogdet': lambda: slogdet(xp.eye(3)), + 'solve': lambda: solve(xp.eye(3), xp.eye(3)), + 'svd': lambda: svd(xp.eye(3)), + 'svdvals': lambda: svdvals(xp.eye(3)), + 'tensordot': lambda: tensordot(xp.eye(3), xp.eye(3)), + 'trace': lambda: trace(xp.eye(3)), + 'vecdot': lambda: vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'vector_norm': lambda: vector_norm(xp.asarray([1., 2., 3.])), } assert set(linalg_examples) == set(xp.linalg.__all__) @@ -210,20 +218,20 @@ def test_linalg(func_name): main_namespace_func() fft_examples = { - 'fft': lambda: xp.fft.fft(xp.asarray([0j, 1j, 0j, 0j])), - 'ifft': lambda: xp.fft.ifft(xp.asarray([0j, 1j, 0j, 0j])), - 'fftn': lambda: xp.fft.fftn(xp.asarray([[0j, 1j], [0j, 0j]])), - 'ifftn': lambda: xp.fft.ifftn(xp.asarray([[0j, 1j], [0j, 0j]])), - 'rfft': lambda: xp.fft.rfft(xp.asarray([0., 1., 0., 0.])), - 'irfft': lambda: xp.fft.irfft(xp.asarray([0j, 1j, 0j, 0j])), - 'rfftn': lambda: xp.fft.rfftn(xp.asarray([[0., 1.], [0., 0.]])), - 'irfftn': lambda: xp.fft.irfftn(xp.asarray([[0j, 1j], [0j, 0j]])), - 'hfft': lambda: xp.fft.hfft(xp.asarray([0j, 1j, 0j, 0j])), - 'ihfft': lambda: xp.fft.ihfft(xp.asarray([0., 1., 0., 0.])), - 'fftfreq': lambda: xp.fft.fftfreq(4), - 'rfftfreq': lambda: xp.fft.rfftfreq(4), - 'fftshift': lambda: xp.fft.fftshift(xp.asarray([0j, 1j, 0j, 0j])), - 'ifftshift': lambda: xp.fft.ifftshift(xp.asarray([0j, 1j, 0j, 0j])), + 'fft': lambda: fft(xp.asarray([0j, 1j, 0j, 0j])), + 'ifft': lambda: ifft(xp.asarray([0j, 1j, 0j, 0j])), + 'fftn': lambda: fftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'ifftn': lambda: ifftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'rfft': lambda: rfft(xp.asarray([0., 1., 0., 0.])), + 'irfft': lambda: irfft(xp.asarray([0j, 1j, 0j, 0j])), + 'rfftn': lambda: rfftn(xp.asarray([[0., 1.], [0., 0.]])), + 'irfftn': lambda: irfftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'hfft': lambda: hfft(xp.asarray([0j, 1j, 0j, 0j])), + 'ihfft': lambda: ihfft(xp.asarray([0., 1., 0., 0.])), + 'fftfreq': lambda: fftfreq(4), + 'rfftfreq': lambda: rfftfreq(4), + 'fftshift': lambda: fftshift(xp.asarray([0j, 1j, 0j, 0j])), + 'ifftshift': lambda: ifftshift(xp.asarray([0j, 1j, 0j, 0j])), } assert set(fft_examples) == set(xp.fft.__all__) @@ -276,3 +284,89 @@ def test_api_version_2023_12(func_name): set_array_api_strict_flags(api_version='2022.12') pytest.raises(RuntimeError, func) + +def test_disabled_extensions(): + # Test that xp.extension errors when an extension is disabled, and that + # xp.__all__ is updated properly. + + # First test that things are correct on the initial import. Since we have + # already called set_array_api_strict_flags many times throughout running + # the tests, we have to test this in a subprocess. + subprocess_tests = [('''\ +import array_api_strict + +array_api_strict.linalg # No error +array_api_strict.fft # No error +assert "linalg" in array_api_strict.__all__ +assert "fft" in array_api_strict.__all__ +assert len(array_api_strict.__all__) == len(set(array_api_strict.__all__)) +''', {}), +# Test that the initial population of __all__ works correctly +('''\ +from array_api_strict import * # No error +linalg # Should have been imported by the previous line +fft +''', {}), +('''\ +from array_api_strict import * # No error +linalg # Should have been imported by the previous line +assert 'fft' not in globals() +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": "linalg"}), +('''\ +from array_api_strict import * # No error +fft # Should have been imported by the previous line +assert 'linalg' not in globals() +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": "fft"}), +('''\ +from array_api_strict import * # No error +assert 'linalg' not in globals() +assert 'fft' not in globals() +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": ""}), +] + for test, env in subprocess_tests: + try: + subprocess.run([sys.executable, '-c', test], check=True, + capture_output=True, encoding='utf-8', env=env) + except subprocess.CalledProcessError as e: + print(e.stdout, end='') + # Ensure the exception is shown in the output log + raise AssertionError(e.stderr) + + assert 'linalg' in xp.__all__ + assert 'fft' in xp.__all__ + xp.linalg # No error + xp.fft # No error + ns = {} + exec('from array_api_strict import *', ns) + assert 'linalg' in ns + assert 'fft' in ns + + set_array_api_strict_flags(enabled_extensions=('linalg',)) + assert 'linalg' in xp.__all__ + assert 'fft' not in xp.__all__ + xp.linalg # No error + pytest.raises(AttributeError, lambda: xp.fft) + ns = {} + exec('from array_api_strict import *', ns) + assert 'linalg' in ns + assert 'fft' not in ns + + set_array_api_strict_flags(enabled_extensions=('fft',)) + assert 'linalg' not in xp.__all__ + assert 'fft' in xp.__all__ + pytest.raises(AttributeError, lambda: xp.linalg) + xp.fft # No error + ns = {} + exec('from array_api_strict import *', ns) + assert 'linalg' not in ns + assert 'fft' in ns + + set_array_api_strict_flags(enabled_extensions=()) + assert 'linalg' not in xp.__all__ + assert 'fft' not in xp.__all__ + pytest.raises(AttributeError, lambda: xp.linalg) + pytest.raises(AttributeError, lambda: xp.fft) + ns = {} + exec('from array_api_strict import *', ns) + assert 'linalg' not in ns + assert 'fft' not in ns From 7bc29d6b638fcdc666028b4e33bd9529e77b4213 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 6 May 2024 15:04:03 -0600 Subject: [PATCH 083/252] Fix setting ARRAY_API_STRICT_API_VERSION to 2021.12 --- array_api_strict/_flags.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 221d0d3..f6cef29 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -305,29 +305,25 @@ def __exit__(self, exc_type, exc_value, traceback): ] def set_flags_from_environment(): + kwargs = {} if "ARRAY_API_STRICT_API_VERSION" in os.environ: - set_array_api_strict_flags( - api_version=os.environ["ARRAY_API_STRICT_API_VERSION"] - ) + kwargs["api_version"] = os.environ["ARRAY_API_STRICT_API_VERSION"] if "ARRAY_API_STRICT_BOOLEAN_INDEXING" in os.environ: - set_array_api_strict_flags( - boolean_indexing=os.environ["ARRAY_API_STRICT_BOOLEAN_INDEXING"].lower() == "true" - ) + kwargs["boolean_indexing"] = os.environ["ARRAY_API_STRICT_BOOLEAN_INDEXING"].lower() == "true" if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ: - set_array_api_strict_flags( - data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true" - ) + kwargs["data_dependent_shapes"] = os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true" if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ: enabled_extensions = os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") if enabled_extensions == [""]: enabled_extensions = [] - set_array_api_strict_flags(enabled_extensions=enabled_extensions) - else: - # Needed at first import to add linalg and fft to __all__ - set_array_api_strict_flags(enabled_extensions=default_extensions) + kwargs["enabled_extensions"] = enabled_extensions + + # Called unconditionally because it is needed at first import to add + # linalg and fft to __all__ + set_array_api_strict_flags(**kwargs) set_flags_from_environment() From c770c9b5503b7e80ba4aac8026ffe06bccbff9cb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 6 May 2024 15:04:19 -0600 Subject: [PATCH 084/252] Add tests for environment variables They're not pretty, but they get the job done. --- array_api_strict/tests/test_flags.py | 121 +++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 38b1a3b..86ad8e2 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -370,3 +370,124 @@ def test_disabled_extensions(): exec('from array_api_strict import *', ns) assert 'linalg' not in ns assert 'fft' not in ns + + +def test_environment_variables(): + # Test that the environment variables work as expected + subprocess_tests = [ + # ARRAY_API_STRICT_API_VERSION + ('''\ +import array_api_strict as xp +assert xp.__array_api_version__ == '2022.12' + +assert xp.get_array_api_strict_flags()['api_version'] == '2022.12' + +''', {}), + *[ + (f'''\ +import array_api_strict as xp +assert xp.__array_api_version__ == '{version}' + +assert xp.get_array_api_strict_flags()['api_version'] == '{version}' + +if {version} == '2021.12': + assert hasattr(xp, 'linalg') + assert not hasattr(xp, 'fft') + +''', {"ARRAY_API_STRICT_API_VERSION": version}) for version in ('2021.12', '2022.12', '2023.12')], + + # ARRAY_API_STRICT_BOOLEAN_INDEXING + ('''\ +import array_api_strict as xp + +a = xp.ones(3) +mask = xp.asarray([True, False, True]) + +assert xp.all(a[mask] == xp.asarray([1., 1.])) +assert xp.get_array_api_strict_flags()['boolean_indexing'] == True +''', {}), + *[(f'''\ +import array_api_strict as xp + +a = xp.ones(3) +mask = xp.asarray([True, False, True]) + +if {boolean_indexing}: + assert xp.all(a[mask] == xp.asarray([1., 1.])) +else: + try: + a[mask] + except RuntimeError: + pass + else: + assert False + +assert xp.get_array_api_strict_flags()['boolean_indexing'] == {boolean_indexing} +''', {"ARRAY_API_STRICT_BOOLEAN_INDEXING": boolean_indexing}) + for boolean_indexing in ('True', 'False')], + + # ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES + ('''\ +import array_api_strict as xp + +a = xp.ones(3) +xp.unique_all(a) + +assert xp.get_array_api_strict_flags()['data_dependent_shapes'] == True +''', {}), + *[(f'''\ +import array_api_strict as xp + +a = xp.ones(3) +if {data_dependent_shapes}: + xp.unique_all(a) +else: + try: + xp.unique_all(a) + except RuntimeError: + pass + else: + assert False + +assert xp.get_array_api_strict_flags()['data_dependent_shapes'] == {data_dependent_shapes} +''', {"ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES": data_dependent_shapes}) + for data_dependent_shapes in ('True', 'False')], + + # ARRAY_API_STRICT_ENABLED_EXTENSIONS + ('''\ +import array_api_strict as xp +assert hasattr(xp, 'linalg') +assert hasattr(xp, 'fft') + +assert xp.get_array_api_strict_flags()['enabled_extensions'] == ('linalg', 'fft') +''', {}), + *[(f'''\ +import array_api_strict as xp + +assert hasattr(xp, 'linalg') == ('linalg' in {extensions.split(',')}) +assert hasattr(xp, 'fft') == ('fft' in {extensions.split(',')}) + +assert sorted(xp.get_array_api_strict_flags()['enabled_extensions']) == {sorted(set(extensions.split(','))-{''})} +''', {"ARRAY_API_STRICT_ENABLED_EXTENSIONS": extensions}) + for extensions in ('', 'linalg', 'fft', 'linalg,fft')], + ] + + for test, env in subprocess_tests: + try: + subprocess.run([sys.executable, '-c', test], check=True, + capture_output=True, encoding='utf-8', env=env) + except subprocess.CalledProcessError as e: + print(e.stdout, end='') + # Ensure the exception is shown in the output log + raise AssertionError(f"""\ +STDOUT: +{e.stderr} + +STDERR: +{e.stderr} + +TEST: +{test} + +ENV: +{env}""") From a8f8fdcd665ed664b5335afda0b909c411529c2a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 6 May 2024 15:18:41 -0600 Subject: [PATCH 085/252] More than signature tests are now implemented for 2023.12 --- .github/workflows/array-api-tests.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index b37ec04..af91d2a 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -52,11 +52,6 @@ jobs: run: | # Parameterizing this in the CI matrix is wasteful. Just do a loop here. for ARRAY_API_STRICT_API_VERSION in ${API_VERSIONS}; do - # Only signature tests work for now for 2023.12 - if [[ "$ARRAY_API_STRICT_API_VERSION" == "2023.12" ]]; then - PYTEST_ARGS="${PYTEST_ARGS} -k signature - fi - cd ${GITHUB_WORKSPACE}/array-api-tests pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS} done From 752b70667aea493b417ee6464080813ce78c4c01 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 16 May 2024 14:08:49 -0600 Subject: [PATCH 086/252] Add more info to an error message --- array_api_strict/_creation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 0e85cdc..67ba67c 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -20,7 +20,7 @@ def _check_valid_dtype(dtype): # Note: Only spelling dtypes as the dtype objects is supported. if dtype not in (None,) + _all_dtypes: - raise ValueError("dtype must be one of the supported dtypes") + raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}") def _supports_buffer_protocol(obj): try: From 7cb321498647c3e17ccf7156b5626609e30e1e19 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 17 May 2024 15:06:52 -0600 Subject: [PATCH 087/252] Fix some issues with cumulative_sum - The behavior for dtype=None was incorrect. - Fix an error with axis=-1, include_initial=True. --- array_api_strict/_statistical_functions.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 7a42d25..39e3736 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -30,8 +30,9 @@ def cumulative_sum( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in cumulative_sum") - if dtype is None: - dtype = x.dtype + dt = x.dtype if dtype is None else dtype + if dtype is not None: + dtype = dtype._np_dtype # TODO: The standard is not clear about what should happen when x.ndim == 0. if axis is None: @@ -40,8 +41,10 @@ def cumulative_sum( axis = 0 # np.cumsum does not support include_initial if include_initial: - x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dtype), x], axis=axis) - return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype._np_dtype)) + if axis < 0: + axis += x.ndim + x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis) + return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype)) def max( x: Array, From beb95ae1857d3cff554ead2f8cc1c23964cfa8e6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 24 May 2024 16:07:02 -0600 Subject: [PATCH 088/252] Fix typo --- array_api_strict/_linear_algebra_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py index 6a1a921..dcb654d 100644 --- a/array_api_strict/_linear_algebra_functions.py +++ b/array_api_strict/_linear_algebra_functions.py @@ -61,7 +61,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: elif axis < min(-1, -x1.ndim, -x2.ndim): raise ValueError("axis is out of bounds for x1 and x2") - # In versions if the standard prior to 2023.12, vecdot applied axis after + # In versions of the standard prior to 2023.12, vecdot applied axis after # broadcasting. This is different from applying it before broadcasting # when axis is nonnegative. The below code keeps this behavior for # 2022.12, primarily for backwards compatibility. Note that the behavior From c721f3d552c7203a7233075d1cf13063d7128039 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 28 May 2024 15:36:04 -0600 Subject: [PATCH 089/252] Remove duplicate __all__ definition from _info.py --- array_api_strict/_info.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index 5f8c841..ab5447a 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -1,14 +1,5 @@ from __future__ import annotations -__all__ = [ - "__array_namespace_info__", - "capabilities", - "default_device", - "default_dtypes", - "devices", - "dtypes", -] - from typing import TYPE_CHECKING if TYPE_CHECKING: From b1e527bc15aadfdcc577cfedf9099ff204b95e60 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 3 Jun 2024 21:43:44 +0000 Subject: [PATCH 090/252] Bump dawidd6/action-download-artifact from 3 to 4 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 3 to 4 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v3...v4) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 79cde63..7c301e5 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download Artifact - uses: dawidd6/action-download-artifact@v3 + uses: dawidd6/action-download-artifact@v4 with: workflow: docs-build.yml name: docs-build From 935ea8722a5ef1d06786e1b148003d72638a3d19 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 7 Jun 2024 14:45:19 -0600 Subject: [PATCH 091/252] Disable array iteration Fixes #41. --- array_api_strict/_array_object.py | 9 +++++++++ array_api_strict/tests/test_array_object.py | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 8849ce3..18ed327 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -647,6 +647,15 @@ def __invert__(self: Array, /) -> Array: res = self._array.__invert__() return self.__class__._new(res) + def __iter__(self: Array, /): + """ + Performs the operation __iter__. + """ + # Manually disable iteration, since __getitem__ raises IndexError on + # things like ones((3, 3))[0], which causes list(ones((3, 3))) to give + # []. + raise TypeError("array iteration is not allowed in array-api-strict") + def __le__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __le__. diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index e061a94..407bff2 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -416,3 +416,7 @@ def test_array_namespace(): pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2023.12")) + +def test_no_iter(): + pytest.raises(TypeError, lambda: iter(ones(3))) + pytest.raises(TypeError, lambda: iter(ones((3, 3)))) From f1d9418548d73d2a12f6e9dcdc6bf958fffef5d3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 10 Jun 2024 21:25:57 +0000 Subject: [PATCH 092/252] Bump dawidd6/action-download-artifact from 4 to 5 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 4 to 5 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v4...v5) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 7c301e5..1810d44 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download Artifact - uses: dawidd6/action-download-artifact@v4 + uses: dawidd6/action-download-artifact@v5 with: workflow: docs-build.yml name: docs-build From e379a636f3b37b8450c82ed24c767063a88ceb1f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 17 Jun 2024 21:12:56 +0000 Subject: [PATCH 093/252] Bump the actions group with 2 updates Bumps the actions group with 2 updates: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact) and [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish). Updates `dawidd6/action-download-artifact` from 5 to 6 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v5...v6) Updates `pypa/gh-action-pypi-publish` from 1.8.14 to 1.9.0 - [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.8.14...v1.9.0) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions - dependency-name: pypa/gh-action-pypi-publish dependency-type: direct:production update-type: version-update:semver-minor dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- .github/workflows/publish-package.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 1810d44..9aa379d 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download Artifact - uses: dawidd6/action-download-artifact@v5 + uses: dawidd6/action-download-artifact@v6 with: workflow: docs-build.yml name: docs-build diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index bfe98bb..5ebe50c 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -97,7 +97,7 @@ jobs: if: >- (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') - uses: pypa/gh-action-pypi-publish@v1.8.14 + uses: pypa/gh-action-pypi-publish@v1.9.0 with: repository-url: https://test.pypi.org/legacy/ print-hash: true @@ -110,6 +110,6 @@ jobs: - name: Publish distribution 📦 to PyPI if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - uses: pypa/gh-action-pypi-publish@v1.8.14 + uses: pypa/gh-action-pypi-publish@v1.9.0 with: print-hash: true From 5cf028c51fa22f2926d3a85deacf21c5815d420c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 26 Jun 2024 15:11:38 -0600 Subject: [PATCH 094/252] Fix NumPy 1.26 type promotion in copysign --- array_api_strict/_elementwise_functions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 9ef71bd..b39bd86 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -328,6 +328,9 @@ def copysign(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in copysign") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.copysign(x1._array, x2._array)) def cos(x: Array, /) -> Array: From 5e607c3f72bcf32443970f4c2b03757f02e6bdd0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 26 Jun 2024 15:19:51 -0600 Subject: [PATCH 095/252] Remove NPY_PROMOTION_STATE=weak from the CI The strict library should be explicitly working around all the bad promotion issues from NumPy 1.26. --- .github/workflows/array-api-tests.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index af91d2a..ab7dbb8 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -46,9 +46,6 @@ jobs: - name: Run the array API testsuite env: ARRAY_API_TESTS_MODULE: array_api_strict - # This enables the NEP 50 type promotion behavior (without it a lot of - # tests fail in numpy 1.26 on bad scalar type promotion behavior) - NPY_PROMOTION_STATE: weak run: | # Parameterizing this in the CI matrix is wasteful. Just do a loop here. for ARRAY_API_STRICT_API_VERSION in ${API_VERSIONS}; do From 6f8c07f548d4e90bcfae71e74a10b0043385adb6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 26 Jun 2024 17:11:12 -0600 Subject: [PATCH 096/252] Trigger CI From a3f6ef690986eecbcb3a6fbfb5d2a5a17fa6a3f3 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 27 Jun 2024 16:04:22 -0600 Subject: [PATCH 097/252] Fix docs build issues --- docs/api.rst | 3 ++- docs/index.md | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index ed702dc..760d11f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -3,6 +3,8 @@ API Reference .. automodule:: array_api_strict +.. _array-api-strict-flags: + Array API Strict Flags ---------------------- @@ -12,7 +14,6 @@ Array API Strict Flags .. autofunction:: get_array_api_strict_flags -.. _set_array_api_strict_flags: .. autofunction:: set_array_api_strict_flags .. autofunction:: reset_array_api_strict_flags .. autoclass:: ArrayAPIStrictFlags diff --git a/docs/index.md b/docs/index.md index fc385d4..3a2edd5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -20,7 +20,7 @@ array-api-strict currently supports the version of the standard. Experimental [2023.12](https://data-apis.org/array-api/latest/changelog.html#v2023-12) support is implemented, [but must be enabled with a -flag](set_array_api_strict_flags). +flag](array-api-strict-flags). ## Install @@ -185,7 +185,7 @@ issue, but this hasn't necessarily been tested thoroughly. 6. array-api-strict currently uses the 2022.12 version of the array API standard. Support for 2023.12 is implemented but is still experimental and not fully tested. It can be enabled with - [`array_api_strict.set_array_api_strict_flags(api_version='2023.12')`](set_array_api_strict_flags). + {func}`array_api_strict.set_array_api_strict_flags(api_version='2023.12') `. (numpy.array_api)= From 4a02c5a4120a77c1c31c88afe536911c529709f2 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 27 Jun 2024 16:12:47 -0600 Subject: [PATCH 098/252] Use numpydoc-style parameters for set_array_api_strict_flags --- array_api_strict/_flags.py | 84 ++++++++++++++++++++------------------ 1 file changed, 44 insertions(+), 40 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index f6cef29..1694650 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -64,46 +64,50 @@ def set_array_api_strict_flags( This function is **not** part of the array API standard. It only exists in array-api-strict. - - `api_version`: The version of the standard to use. Supported - versions are: ``{supported_versions}``. The default version number is - ``{default_version!r}``. - - Note that 2021.12 is supported, but currently gives the same thing as - 2022.12 (except that the fft extension will be disabled). - - 2023.12 support is experimental. Some features in 2023.12 may still be - missing, and it hasn't been fully tested. - - - `boolean_indexing`: Whether indexing by a boolean array is supported. - Note that although boolean array indexing does result in data-dependent - shapes, this flag is independent of the `data_dependent_shapes` flag - (see below). - - - `data_dependent_shapes`: Whether data-dependent shapes are enabled in - array-api-strict. - - This flag is enabled by default. Array libraries that use computation - graphs may not be able to support functions whose output shapes depend - on the input data. - - The functions that make use of data-dependent shapes, and are therefore - disabled by setting this flag to False are - - - `unique_all()`, `unique_counts()`, `unique_inverse()`, and `unique_values()`. - - `nonzero()` - - `repeat()` when the `repeats` argument is an array (requires 2023.12 - version of the standard) - - Note that while boolean indexing is also data-dependent, it is - controlled by a separate `boolean_indexing` flag (see above). - - See - https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html - for more details. - - - `enabled_extensions`: A list of extensions that are enabled in - array-api-strict. The default is ``{default_extensions}``. Note that - some extensions require a minimum version of the standard. + Parameters + ---------- + api_version : str, optional + The version of the standard to use. Supported versions are: + ``{supported_versions}``. The default version number is + ``{default_version!r}``. + + Note that 2021.12 is supported, but currently gives the same thing as + 2022.12 (except that the fft extension will be disabled). + + 2023.12 support is experimental. Some features in 2023.12 may still be + missing, and it hasn't been fully tested. + + boolean_indexing : bool, optional + Whether indexing by a boolean array is supported. + Note that although boolean array indexing does result in + data-dependent shapes, this flag is independent of the + `data_dependent_shapes` flag (see below). + + data_dependent_shapes : bool, optional + Whether data-dependent shapes are enabled in array-api-strict. + This flag is enabled by default. Array libraries that use computation + graphs may not be able to support functions whose output shapes depend + on the input data. + + The functions that make use of data-dependent shapes, and are therefore + disabled by setting this flag to False are + + - `unique_all()`, `unique_counts()`, `unique_inverse()`, and `unique_values()`. + - `nonzero()` + - `repeat()` when the `repeats` argument is an array (requires 2023.12 + version of the standard) + + Note that while boolean indexing is also data-dependent, it is + controlled by a separate `boolean_indexing` flag (see above). + + See + https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html + for more details. + + enabled_extensions : list of str, optional + A list of extensions that are enabled in array-api-strict. The default + is ``{default_extensions}``. Note that some extensions require a + minimum version of the standard. The flags can also be changed by setting :ref:`environment variables `. From 64e7236ddeef458124f44a1c978d0e98ce1b8223 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 27 Jun 2024 16:13:29 -0600 Subject: [PATCH 099/252] Note about the default version changing in the future --- array_api_strict/_flags.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 1694650..c61ea1b 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -75,7 +75,8 @@ def set_array_api_strict_flags( 2022.12 (except that the fft extension will be disabled). 2023.12 support is experimental. Some features in 2023.12 may still be - missing, and it hasn't been fully tested. + missing, and it hasn't been fully tested. A future version of + array-api-strict will change the default version to 2023.12. boolean_indexing : bool, optional Whether indexing by a boolean array is supported. From 0b47a05c61b1eb34b7385954943fd8ff617a86bd Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 27 Jun 2024 16:23:08 -0600 Subject: [PATCH 100/252] Fix issues with set_array_api_strict_flags docstring --- array_api_strict/_flags.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index c61ea1b..8b46374 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -59,6 +59,8 @@ def set_array_api_strict_flags( Flags are global variables that enable or disable array-api-strict behaviors. + The flags can also be changed by setting :ref:`environment variables `. + .. note:: This function is **not** part of the array API standard. It only exists @@ -95,8 +97,8 @@ def set_array_api_strict_flags( - `unique_all()`, `unique_counts()`, `unique_inverse()`, and `unique_values()`. - `nonzero()` - - `repeat()` when the `repeats` argument is an array (requires 2023.12 - version of the standard) + - `repeat()` when the `repeats` argument is an array (requires the + 2023.12 version of the standard) Note that while boolean indexing is also data-dependent, it is controlled by a separate `boolean_indexing` flag (see above). @@ -110,8 +112,6 @@ def set_array_api_strict_flags( is ``{default_extensions}``. Note that some extensions require a minimum version of the standard. - The flags can also be changed by setting :ref:`environment variables - `. Examples -------- From 8716970263f1be6ad294aec25a95341204eef00d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 27 Jun 2024 16:26:03 -0600 Subject: [PATCH 101/252] Add changelog entries for 2.0 release --- docs/changelog.md | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index 04c383d..3d5bf32 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,33 @@ # Changelog +## 2.0 (2024-06-27) + +### Major Changes + +- array-api-strict has a new set of [flags](array-api-strict-flags) that can + be used to dynamically enable or disable features in array-api-strict. These + flags allow you to change the supported array API version, enable or disable + [extensions](https://data-apis.org/array-api/latest/extensions/index.html), + enable or disable features that rely on data-dependent shapes, and enable or + disable boolean indexing. Future versions may add additional flags to allow + changing other optional or varying behavior in the standard. + +- Added experimental support for the + [2023.12](https://data-apis.org/array-api/2023.12/changelog.html#v2023-12) + version of the array API standard. The default version is still 2022.12, but + the version can be changed to 2023.12 using the aforementioned flags, either + by calling + {func}`array_api_strict.set_array_api_strict_flags(api_version='2023.12') + ` or setting the environment + variable {envvar}`ARRAY_API_STRICT_API_VERSION=2023.12 + `. + +### Minor Changes + +- Calling `iter()` on an array now correctly raises `TypeError`. + +- Add some missing names to `__all__`. + ## 1.1.1 (2024-04-29) - Fix the `api_version` argument to `__array_namespace__` to accept From a2cbb4e356a6948deea82797fc85e1d9628c4e5d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 27 Jun 2024 16:31:06 -0600 Subject: [PATCH 102/252] Small changes to the CHANGELOG --- docs/changelog.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 3d5bf32..9c5da3c 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -10,7 +10,7 @@ [extensions](https://data-apis.org/array-api/latest/extensions/index.html), enable or disable features that rely on data-dependent shapes, and enable or disable boolean indexing. Future versions may add additional flags to allow - changing other optional or varying behavior in the standard. + changing other optional or varying behaviors in the standard. - Added experimental support for the [2023.12](https://data-apis.org/array-api/2023.12/changelog.html#v2023-12) @@ -20,7 +20,8 @@ {func}`array_api_strict.set_array_api_strict_flags(api_version='2023.12') ` or setting the environment variable {envvar}`ARRAY_API_STRICT_API_VERSION=2023.12 - `. + `. A future version of array-api-strict will + change the default version to 2023.12. ### Minor Changes From b2a395e9b509e15664f212241803fa4c0af94d74 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 27 Jun 2024 16:36:57 -0600 Subject: [PATCH 103/252] NumPy 2.0 is no longer upcoming --- docs/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.md b/docs/index.md index 3a2edd5..a9578a1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -37,7 +37,7 @@ and [Conda-forge](https://anaconda.org/conda-forge/array-api-strict) conda install --channel conda-forge array-api-strict ``` -array-api-strict supports NumPy 1.26 and (the upcoming) NumPy 2.0. +array-api-strict supports NumPy 1.26 and NumPy 2.0. ## Rationale From c68650e29bd31f4335b4d872f46293968d107c04 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 27 Jun 2024 16:37:07 -0600 Subject: [PATCH 104/252] Update some text in the docs index page --- docs/index.md | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/docs/index.md b/docs/index.md index a9578a1..f37479c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -170,23 +170,17 @@ issue, but this hasn't necessarily been tested thoroughly. this deviation may be tested with type checking. This [behavior may improve in the future](https://github.com/data-apis/array-api-strict/issues/6). -5. There are some behaviors in the standard that are not required to be - implemented by libraries that cannot support [data dependent - shapes](https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html). - This includes [the `unique_*` - functions](https://data-apis.org/array-api/latest/API_specification/set_functions.html), - [boolean array - indexing](https://data-apis.org/array-api/latest/API_specification/indexing.html#boolean-array-indexing), - and the - [`nonzero`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.nonzero.html) - function. array-api-strict currently implements all of these. In the - future, [there may be a way to disable them](https://github.com/data-apis/array-api-strict/issues/7). +5. By default, all extensions in the standard are enabled, as well as optional + behaviors such as data-dependent shapes and boolean indexing. These can be + disabled with the [array-api-strict flags](array-api-strict-flags). 6. array-api-strict currently uses the 2022.12 version of the array API - standard. Support for 2023.12 is implemented but is still experimental and - not fully tested. It can be enabled with - {func}`array_api_strict.set_array_api_strict_flags(api_version='2023.12') `. - + standard by default. Support for 2023.12 is implemented but is still + experimental and not fully tested. It can be enabled with + {func}`array_api_strict.set_array_api_strict_flags(api_version='2023.12') + ` or by setting the + environment variable {envvar}`ARRAY_API_STRICT_API_VERSION=2023.12 + `. (numpy.array_api)= ## Relationship to `numpy.array_api` From 50a16699eeea8466fac3a6a5f5083e81482501af Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 27 Jun 2024 16:39:41 -0600 Subject: [PATCH 105/252] Move some text around in the index page --- docs/index.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/index.md b/docs/index.md index f37479c..0465a8d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -131,6 +131,12 @@ strict/minimal: functions](https://data-apis.org/array-api/latest/API_specification/creation_functions.html) such as `asarray`. +- Optional behavior such as [optional + extensions](https://data-apis.org/array-api/latest/extensions/index.html), + functions that use data-dependent shapes, and boolean indexing are enabled + by default but can disabled with the [array-api-strict + flags](array-api-strict-flags). + ## Caveats array-api-strict is a thin pure Python wrapper around NumPy. NumPy 2.0 fully @@ -170,11 +176,7 @@ issue, but this hasn't necessarily been tested thoroughly. this deviation may be tested with type checking. This [behavior may improve in the future](https://github.com/data-apis/array-api-strict/issues/6). -5. By default, all extensions in the standard are enabled, as well as optional - behaviors such as data-dependent shapes and boolean indexing. These can be - disabled with the [array-api-strict flags](array-api-strict-flags). - -6. array-api-strict currently uses the 2022.12 version of the array API +5. array-api-strict currently uses the 2022.12 version of the array API standard by default. Support for 2023.12 is implemented but is still experimental and not fully tested. It can be enabled with {func}`array_api_strict.set_array_api_strict_flags(api_version='2023.12') From eb8dd464eea5873a59f73c109d8e3347918319d7 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 27 Jun 2024 16:40:56 -0600 Subject: [PATCH 106/252] Clarify some text --- docs/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.md b/docs/index.md index 0465a8d..a14fbcb 100644 --- a/docs/index.md +++ b/docs/index.md @@ -109,7 +109,7 @@ strict/minimal: like `sin` will accept integer array inputs, but the [standard only requires them to accept floating-point inputs](https://data-apis.org/array-api/latest/API_specification/generated/array_api.sin.html#array_api.sin), - so in array-api-strict, `sin(integer_array)` will raise an exception. + so in array-api-strict, `sin(asarray(0))` will raise an exception. - The [indexing](https://data-apis.org/array-api/latest/API_specification/indexing.html) From e246809bcc2b802cf557f52b20024fc9480839cb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 27 Jun 2024 16:48:50 -0600 Subject: [PATCH 107/252] Remove the hypothesis pin on CI --- .github/workflows/array-api-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index ab7dbb8..36ef85c 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -42,7 +42,7 @@ jobs: fi python -m pip install ${GITHUB_WORKSPACE}/array-api-strict python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt - python -m pip install hypothesis==6.97.1 + python -m pip install hypothesis - name: Run the array API testsuite env: ARRAY_API_TESTS_MODULE: array_api_strict From 34d1645b594d9851a44f42a4b429c2ccbeb5b69d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 27 Jun 2024 16:49:13 -0600 Subject: [PATCH 108/252] Test the array-api-tests branch without the data_too_large health checks disabled --- .github/workflows/array-api-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 36ef85c..508ba82 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -28,6 +28,7 @@ jobs: repository: data-apis/array-api-tests submodules: 'true' path: array-api-tests + branch: revert-278-suppress-data_too_large - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: From 43f5a6ecd5876a928cb0db57dd0615283a411a17 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 27 Jun 2024 17:01:22 -0600 Subject: [PATCH 109/252] Some small fixes to docs --- array_api_strict/_flags.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 8b46374..62acddf 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -81,8 +81,8 @@ def set_array_api_strict_flags( array-api-strict will change the default version to 2023.12. boolean_indexing : bool, optional - Whether indexing by a boolean array is supported. - Note that although boolean array indexing does result in + Whether indexing by a boolean array is supported. This flag is enabled + by default. Note that although boolean array indexing does result in data-dependent shapes, this flag is independent of the `data_dependent_shapes` flag (see below). @@ -276,6 +276,19 @@ class ArrayAPIStrictFlags: See :func:`set_array_api_strict_flags` for a description of the available flags. + Examples + -------- + + >>> from array_api_strict import ArrayAPIStrictFlags, get_array_api_strict_flags + >>> with ArrayAPIStrictFlags(api_version="2022.12", boolean_indexing=False): + ... flags = get_array_api_strict_flags() + >>> flags + {'api_version': '2022.12', + 'boolean_indexing': False, + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft') + } + See Also -------- From 0a4813ea872aedb7aac05be530415ac5b7747eef Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 1 Jul 2024 11:24:09 -0600 Subject: [PATCH 110/252] Allow iteration on 1-D arrays upstream code (see scipy/scipy#21074). The standard does not define iteration, but it does define __getitem__, and the default Python __iter__ implements iteration when getitem is defined as a[0], a[1], ..., implying that iteration ought to work for 1-D arrays. Iteration is still disallowed for higher dimensional arrays, since getitem would not necessarily work with a single integer index (and this is the case that is controversial). In those cases, the new unstack() function would be preferable. At best it would be good to get upstream clarification from the standard whether iteration should always work or not before disallowing 1-D array iteration. --- array_api_strict/_array_object.py | 14 ++++++++++---- array_api_strict/tests/test_array_object.py | 10 ++++++++-- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index cc6bd1a..d8ed018 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -677,10 +677,16 @@ def __iter__(self: Array, /): """ Performs the operation __iter__. """ - # Manually disable iteration, since __getitem__ raises IndexError on - # things like ones((3, 3))[0], which causes list(ones((3, 3))) to give - # []. - raise TypeError("array iteration is not allowed in array-api-strict") + # Manually disable iteration on higher dimensional arrays, since + # __getitem__ raises IndexError on things like ones((3, 3))[0], which + # causes list(ones((3, 3))) to give []. + if self.ndim > 1: + raise TypeError("array iteration is not allowed in array-api-strict") + # Allow iteration for 1-D arrays. The array API doesn't strictly + # define __iter__, but it doesn't disallow it. The default Python + # behavior is to implement iter as a[0], a[1], ... when __getitem__ is + # implemented, which implies iteration on 1-D arrays. + return (Array._new(i) for i in self._array) def __le__(self: Array, other: Union[int, float, Array], /) -> Array: """ diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index b28c747..b0d4868 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -1,4 +1,5 @@ import operator +from builtins import all as all_ from numpy.testing import assert_raises, suppress_warnings import numpy as np @@ -21,6 +22,7 @@ int32, int64, uint64, + float64, bool as bool_, ) from .._flags import set_array_api_strict_flags @@ -423,8 +425,12 @@ def test_array_namespace(): pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) -def test_no_iter(): - pytest.raises(TypeError, lambda: iter(ones(3))) +def test_iter(): + pytest.raises(TypeError, lambda: iter(asarray(3))) + assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)] + assert all_(isinstance(a, Array) for a in iter(ones(3))) + assert all_(a.shape == () for a in iter(ones(3))) + assert all_(a.dtype == float64 for a in iter(ones(3))) pytest.raises(TypeError, lambda: iter(ones((3, 3)))) @pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) From e88bd9987dc9bca92a3ba0dbef27194b15bb9c62 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 1 Jul 2024 12:13:15 -0600 Subject: [PATCH 111/252] Revert "Test the array-api-tests branch without the data_too_large health checks disabled" This reverts commit 34d1645b594d9851a44f42a4b429c2ccbeb5b69d. --- .github/workflows/array-api-tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 508ba82..36ef85c 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -28,7 +28,6 @@ jobs: repository: data-apis/array-api-tests submodules: 'true' path: array-api-tests - branch: revert-278-suppress-data_too_large - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: From 9c9ff8b82976d9ef98e0812c30c5099e6aed9e8f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 1 Jul 2024 13:25:45 -0600 Subject: [PATCH 112/252] Add a changelog for 2.0.1 --- docs/changelog.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index 9c5da3c..7b40fe3 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,14 @@ # Changelog +### 2.0.1 (2024-07-01) + +## Minor Changes + +- Re-allow iteration on 1-D arrays. A change from 2.0 fixed iter() raising on + n-D arrays but also made 1-D arrays raise. The standard does not explicitly + disallow iteration on 1-D arrays, and the default Python `__iter__` + implementation allows it to work, so for now, it is kept intact as working. + ## 2.0 (2024-06-27) ### Major Changes From 5379bd5796c6c7039033c6f59e2b5923f3fc9162 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 10 Jul 2024 13:01:20 -0600 Subject: [PATCH 113/252] Allow any combination of real dtypes in comparisons This does not change == or != because the standard is currently unclear about that so I'd like to see what happens there first. --- array_api_strict/_array_object.py | 19 +++++++++++++------ array_api_strict/_elementwise_functions.py | 8 -------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index d8ed018..86d2b8b 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -152,7 +152,13 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None # spec in places where it either deviates from or is more strict than # NumPy behavior - def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array: + def _check_allowed_dtypes( + self, + other: bool | int | float | Array, + dtype_category: str, + op: str, + check_promotion: bool = True, + ) -> Array: """ Helper function for operators to only allow specific input dtypes @@ -176,7 +182,8 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor # This will raise TypeError for type combinations that are not allowed # to promote in the spec (even if the NumPy array operator would # promote them). - res_dtype = _result_type(self.dtype, other.dtype) + if check_promotion: + res_dtype = _result_type(self.dtype, other.dtype) if op.startswith("__i"): # Note: NumPy will allow in-place operators in some cases where # the type promoted operator does not match the left-hand side @@ -604,7 +611,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ge__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__ge__") + other = self._check_allowed_dtypes(other, "real numeric", "__ge__", check_promotion=False) if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -638,7 +645,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __gt__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__gt__") + other = self._check_allowed_dtypes(other, "real numeric", "__gt__", check_promotion=False) if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -692,7 +699,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __le__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__le__") + other = self._check_allowed_dtypes(other, "real numeric", "__le__", check_promotion=False) if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -714,7 +721,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __lt__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__lt__") + other = self._check_allowed_dtypes(other, "real numeric", "__lt__", check_promotion=False) if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index b39bd86..f0c94db 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -439,8 +439,6 @@ def greater(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in greater") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater(x1._array, x2._array)) @@ -453,8 +451,6 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in greater_equal") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater_equal(x1._array, x2._array)) @@ -524,8 +520,6 @@ def less(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in less") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.less(x1._array, x2._array)) @@ -538,8 +532,6 @@ def less_equal(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in less_equal") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.less_equal(x1._array, x2._array)) From b47039fbe6bb6b67a8be1d739a9b81c24e42e3f2 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 10 Jul 2024 13:24:13 -0600 Subject: [PATCH 114/252] Add helpful error messages to assert_raises calls in test_array_object.py --- array_api_strict/tests/test_array_object.py | 65 ++++++++++----------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index b0d4868..d4b8794 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -1,7 +1,7 @@ import operator from builtins import all as all_ -from numpy.testing import assert_raises, suppress_warnings +import numpy.testing import numpy as np import pytest @@ -29,6 +29,10 @@ import array_api_strict +def assert_raises(exception, func, msg=None): + with numpy.testing.assert_raises(exception, msg=msg): + func() + def test_validate_index(): # The indexing tests in the official array API test suite test that the # array object correctly handles the subset of indices that are required @@ -111,6 +115,7 @@ def test_operators(): "__truediv__": "floating", "__xor__": "integer_or_boolean", } + comparison_ops = ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"] # Recompute each time because of in-place ops def _array_vals(): for d in _integer_dtypes: @@ -124,7 +129,7 @@ def _array_vals(): BIG_INT = int(1e30) for op, dtypes in binary_op_dtypes.items(): ops = [op] - if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]: + if op not in comparison_ops: rop = "__r" + op[2:] iop = "__i" + op[2:] ops += [rop, iop] @@ -155,16 +160,16 @@ def _array_vals(): or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int] )): if a.dtype in _integer_dtypes and s == BIG_INT: - assert_raises(OverflowError, lambda: getattr(a, _op)(s)) + assert_raises(OverflowError, lambda: getattr(a, _op)(s), _op) else: # Only test for no error - with suppress_warnings() as sup: + with numpy.testing.suppress_warnings() as sup: # ignore warnings from pow(BIG_INT) sup.filter(RuntimeWarning, "invalid value encountered in power") getattr(a, _op)(s) else: - assert_raises(TypeError, lambda: getattr(a, _op)(s)) + assert_raises(TypeError, lambda: getattr(a, _op)(s), _op) # Test array op array. for _op in ops: @@ -188,7 +193,7 @@ def _array_vals(): _op.startswith("__i") and result_type(x.dtype, y.dtype) != x.dtype ): - assert_raises(TypeError, lambda: getattr(x, _op)(y)) + assert_raises(TypeError, lambda: getattr(x, _op)(y), _op) # Ensure only those dtypes that are required for every operator are allowed. elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) @@ -202,7 +207,7 @@ def _array_vals(): ): getattr(x, _op)(y) else: - assert_raises(TypeError, lambda: getattr(x, _op)(y)) + assert_raises(TypeError, lambda: getattr(x, _op)(y), _op) unary_op_dtypes = { "__abs__": "numeric", @@ -221,7 +226,7 @@ def _array_vals(): # Only test for no error getattr(a, op)() else: - assert_raises(TypeError, lambda: getattr(a, op)()) + assert_raises(TypeError, lambda: getattr(a, op)(), _op) # Finally, matmul() must be tested separately, because it works a bit # different from the other operations. @@ -240,9 +245,9 @@ def _matmul_array_vals(): or type(s) == int and a.dtype in _integer_dtypes): # Type promotion is valid, but @ is not allowed on 0-D # inputs, so the error is a ValueError - assert_raises(ValueError, lambda: getattr(a, _op)(s)) + assert_raises(ValueError, lambda: getattr(a, _op)(s), _op) else: - assert_raises(TypeError, lambda: getattr(a, _op)(s)) + assert_raises(TypeError, lambda: getattr(a, _op)(s), _op) for x in _matmul_array_vals(): for y in _matmul_array_vals(): @@ -356,20 +361,17 @@ def test_allow_newaxis(): def test_disallow_flat_indexing_with_newaxis(): a = ones((3, 3, 3)) - with pytest.raises(IndexError): - a[None, 0, 0] + assert_raises(IndexError, lambda: a[None, 0, 0]) def test_disallow_mask_with_newaxis(): a = ones((3, 3, 3)) - with pytest.raises(IndexError): - a[None, asarray(True)] + assert_raises(IndexError, lambda: a[None, asarray(True)]) @pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)]) @pytest.mark.parametrize("index", ["string", False, True]) def test_error_on_invalid_index(shape, index): a = ones(shape) - with pytest.raises(IndexError): - a[index] + assert_raises(IndexError, lambda: a[index]) def test_mask_0d_array_without_errors(): a = ones(()) @@ -380,10 +382,8 @@ def test_mask_0d_array_without_errors(): ) def test_error_on_invalid_index_with_ellipsis(i): a = ones((3, 3, 3)) - with pytest.raises(IndexError): - a[..., i] - with pytest.raises(IndexError): - a[i, ...] + assert_raises(IndexError, lambda: a[..., i]) + assert_raises(IndexError, lambda: a[i, ...]) def test_array_keys_use_private_array(): """ @@ -400,8 +400,7 @@ def test_array_keys_use_private_array(): a = ones((0,), dtype=bool_) key = ones((0, 0), dtype=bool_) - with pytest.raises(IndexError): - a[key] + assert_raises(IndexError, lambda: a[key]) def test_array_namespace(): a = ones((3, 3)) @@ -422,16 +421,16 @@ def test_array_namespace(): assert a.__array_namespace__(api_version="2021.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2021.12" - pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) - pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) + assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) + assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) def test_iter(): - pytest.raises(TypeError, lambda: iter(asarray(3))) + assert_raises(TypeError, lambda: iter(asarray(3))) assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)] assert all_(isinstance(a, Array) for a in iter(ones(3))) assert all_(a.shape == () for a in iter(ones(3))) assert all_(a.dtype == float64 for a in iter(ones(3))) - pytest.raises(TypeError, lambda: iter(ones((3, 3)))) + assert_raises(TypeError, lambda: iter(ones((3, 3)))) @pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) def dlpack_2023_12(api_version): @@ -447,17 +446,17 @@ def dlpack_2023_12(api_version): exception = NotImplementedError if api_version >= '2023.12' else ValueError - pytest.raises(exception, lambda: + assert_raises(exception, lambda: a.__dlpack__(dl_device=CPU_DEVICE)) - pytest.raises(exception, lambda: + assert_raises(exception, lambda: a.__dlpack__(dl_device=None)) - pytest.raises(exception, lambda: + assert_raises(exception, lambda: a.__dlpack__(max_version=(1, 0))) - pytest.raises(exception, lambda: + assert_raises(exception, lambda: a.__dlpack__(max_version=None)) - pytest.raises(exception, lambda: + assert_raises(exception, lambda: a.__dlpack__(copy=False)) - pytest.raises(exception, lambda: + assert_raises(exception, lambda: a.__dlpack__(copy=True)) - pytest.raises(exception, lambda: + assert_raises(exception, lambda: a.__dlpack__(copy=None)) From 1f8769914880a626eb1bca6ccbe0399a47d73694 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 10 Jul 2024 13:33:21 -0600 Subject: [PATCH 115/252] Allow all dtypes in equal/not_equal/==/!= and update tests Also update elementwise function tests to check for disallowed type promotions, not just disallowed mixed kind types. --- array_api_strict/_array_object.py | 5 ++-- array_api_strict/_elementwise_functions.py | 4 --- array_api_strict/tests/test_array_object.py | 28 ++++++++--------- .../tests/test_elementwise_functions.py | 30 +++++++++++++++++-- 4 files changed, 45 insertions(+), 22 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 86d2b8b..bded0c6 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -157,6 +157,7 @@ def _check_allowed_dtypes( other: bool | int | float | Array, dtype_category: str, op: str, + *, check_promotion: bool = True, ) -> Array: """ @@ -577,7 +578,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ # Even though "all" dtypes are allowed, we still require them to be # promotable with each other. - other = self._check_allowed_dtypes(other, "all", "__eq__") + other = self._check_allowed_dtypes(other, "all", "__eq__", check_promotion=False) if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -766,7 +767,7 @@ def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ Performs the operation __ne__. """ - other = self._check_allowed_dtypes(other, "all", "__ne__") + other = self._check_allowed_dtypes(other, "all", "__ne__", check_promotion=False) if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index f0c94db..d4a108d 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -375,8 +375,6 @@ def equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.equal(x1._array, x2._array)) @@ -707,8 +705,6 @@ def not_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.not_equal(x1._array, x2._array)) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index d4b8794..04e606e 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -94,7 +94,7 @@ def test_validate_index(): def test_operators(): # For every operator, we test that it works for the required type - # combinations and raises TypeError otherwise + # combinations and assert_raises TypeError otherwise binary_op_dtypes = { "__add__": "numeric", "__and__": "integer_or_boolean", @@ -178,16 +178,17 @@ def _array_vals(): # See the promotion table in NEP 47 or the array # API spec page on type promotion. Mixed kind # promotion is not defined. - if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] - or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] - or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes - or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes - or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes - or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes - or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes - or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes - ): - assert_raises(TypeError, lambda: getattr(x, _op)(y)) + if (op not in comparison_ops and + (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] + or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes + or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes + or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes + or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes + or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes + or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes + )): + assert_raises(TypeError, lambda: getattr(x, _op)(y), _op) # Ensure in-place operators only promote to the same dtype as the left operand. elif ( _op.startswith("__i") @@ -195,8 +196,7 @@ def _array_vals(): ): assert_raises(TypeError, lambda: getattr(x, _op)(y), _op) # Ensure only those dtypes that are required for every operator are allowed. - elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes - or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) + elif (dtypes == "all" or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes) or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes @@ -207,7 +207,7 @@ def _array_vals(): ): getattr(x, _op)(y) else: - assert_raises(TypeError, lambda: getattr(x, _op)(y), _op) + assert_raises(TypeError, lambda: getattr(x, _op)(y), (x, _op, y)) unary_op_dtypes = { "__abs__": "numeric", diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 90994f3..92c9c59 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,6 +1,6 @@ from inspect import getfullargspec, getmodule -from numpy.testing import assert_raises +from .test_array_object import assert_raises from .. import asarray, _elementwise_functions from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift @@ -9,6 +9,11 @@ _boolean_dtypes, _floating_dtypes, _integer_dtypes, + int8, + int16, + int32, + int64, + uint64, ) from .._flags import set_array_api_strict_flags @@ -86,6 +91,15 @@ def nargs(func): "trunc": "real numeric", } +comparison_functions = [ + 'equal', + 'greater', + 'greater_equal', + 'less', + 'less_equal', + 'not_equal', +] + def test_missing_functions(): # Ensure the above dictionary is complete. import array_api_strict._elementwise_functions as mod @@ -115,8 +129,20 @@ def _array_vals(): func = getattr(_elementwise_functions, func_name) if nargs(func) == 2: for y in _array_vals(): + # Disallow dtypes that aren't type promotable + if (func_name not in comparison_functions and + (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] + or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes + or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes + or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes + or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes + or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes + or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes + )): + assert_raises(TypeError, lambda: func(x, y), (func_name, x, y)) if x.dtype not in dtypes or y.dtype not in dtypes: - assert_raises(TypeError, lambda: func(x, y)) + assert_raises(TypeError, lambda: func(x, y), (func_name, x, y)) else: if x.dtype not in dtypes: assert_raises(TypeError, lambda: func(x)) From 899ad12bdc365bce8d97e84b615cca9ed0ad33da Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 18 Jul 2024 15:44:36 -0600 Subject: [PATCH 116/252] Revert "Allow any combination of real dtypes in comparisons" --- array_api_strict/_array_object.py | 24 ++--- array_api_strict/_elementwise_functions.py | 12 +++ array_api_strict/tests/test_array_object.py | 91 ++++++++++--------- .../tests/test_elementwise_functions.py | 30 +----- 4 files changed, 68 insertions(+), 89 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index bded0c6..d8ed018 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -152,14 +152,7 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None # spec in places where it either deviates from or is more strict than # NumPy behavior - def _check_allowed_dtypes( - self, - other: bool | int | float | Array, - dtype_category: str, - op: str, - *, - check_promotion: bool = True, - ) -> Array: + def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array: """ Helper function for operators to only allow specific input dtypes @@ -183,8 +176,7 @@ def _check_allowed_dtypes( # This will raise TypeError for type combinations that are not allowed # to promote in the spec (even if the NumPy array operator would # promote them). - if check_promotion: - res_dtype = _result_type(self.dtype, other.dtype) + res_dtype = _result_type(self.dtype, other.dtype) if op.startswith("__i"): # Note: NumPy will allow in-place operators in some cases where # the type promoted operator does not match the left-hand side @@ -578,7 +570,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ # Even though "all" dtypes are allowed, we still require them to be # promotable with each other. - other = self._check_allowed_dtypes(other, "all", "__eq__", check_promotion=False) + other = self._check_allowed_dtypes(other, "all", "__eq__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -612,7 +604,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ge__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__ge__", check_promotion=False) + other = self._check_allowed_dtypes(other, "real numeric", "__ge__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -646,7 +638,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __gt__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__gt__", check_promotion=False) + other = self._check_allowed_dtypes(other, "real numeric", "__gt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -700,7 +692,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __le__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__le__", check_promotion=False) + other = self._check_allowed_dtypes(other, "real numeric", "__le__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -722,7 +714,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __lt__. """ - other = self._check_allowed_dtypes(other, "real numeric", "__lt__", check_promotion=False) + other = self._check_allowed_dtypes(other, "real numeric", "__lt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -767,7 +759,7 @@ def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ Performs the operation __ne__. """ - other = self._check_allowed_dtypes(other, "all", "__ne__", check_promotion=False) + other = self._check_allowed_dtypes(other, "all", "__ne__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index d4a108d..b39bd86 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -375,6 +375,8 @@ def equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.equal(x1._array, x2._array)) @@ -437,6 +439,8 @@ def greater(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in greater") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater(x1._array, x2._array)) @@ -449,6 +453,8 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in greater_equal") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.greater_equal(x1._array, x2._array)) @@ -518,6 +524,8 @@ def less(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in less") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.less(x1._array, x2._array)) @@ -530,6 +538,8 @@ def less_equal(x1: Array, x2: Array, /) -> Array: """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in less_equal") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.less_equal(x1._array, x2._array)) @@ -705,6 +715,8 @@ def not_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.not_equal(x1._array, x2._array)) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 04e606e..b0d4868 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -1,7 +1,7 @@ import operator from builtins import all as all_ -import numpy.testing +from numpy.testing import assert_raises, suppress_warnings import numpy as np import pytest @@ -29,10 +29,6 @@ import array_api_strict -def assert_raises(exception, func, msg=None): - with numpy.testing.assert_raises(exception, msg=msg): - func() - def test_validate_index(): # The indexing tests in the official array API test suite test that the # array object correctly handles the subset of indices that are required @@ -94,7 +90,7 @@ def test_validate_index(): def test_operators(): # For every operator, we test that it works for the required type - # combinations and assert_raises TypeError otherwise + # combinations and raises TypeError otherwise binary_op_dtypes = { "__add__": "numeric", "__and__": "integer_or_boolean", @@ -115,7 +111,6 @@ def test_operators(): "__truediv__": "floating", "__xor__": "integer_or_boolean", } - comparison_ops = ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"] # Recompute each time because of in-place ops def _array_vals(): for d in _integer_dtypes: @@ -129,7 +124,7 @@ def _array_vals(): BIG_INT = int(1e30) for op, dtypes in binary_op_dtypes.items(): ops = [op] - if op not in comparison_ops: + if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]: rop = "__r" + op[2:] iop = "__i" + op[2:] ops += [rop, iop] @@ -160,16 +155,16 @@ def _array_vals(): or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int] )): if a.dtype in _integer_dtypes and s == BIG_INT: - assert_raises(OverflowError, lambda: getattr(a, _op)(s), _op) + assert_raises(OverflowError, lambda: getattr(a, _op)(s)) else: # Only test for no error - with numpy.testing.suppress_warnings() as sup: + with suppress_warnings() as sup: # ignore warnings from pow(BIG_INT) sup.filter(RuntimeWarning, "invalid value encountered in power") getattr(a, _op)(s) else: - assert_raises(TypeError, lambda: getattr(a, _op)(s), _op) + assert_raises(TypeError, lambda: getattr(a, _op)(s)) # Test array op array. for _op in ops: @@ -178,25 +173,25 @@ def _array_vals(): # See the promotion table in NEP 47 or the array # API spec page on type promotion. Mixed kind # promotion is not defined. - if (op not in comparison_ops and - (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] - or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] - or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes - or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes - or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes - or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes - or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes - or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes - )): - assert_raises(TypeError, lambda: getattr(x, _op)(y), _op) + if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] + or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes + or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes + or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes + or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes + or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes + or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes + ): + assert_raises(TypeError, lambda: getattr(x, _op)(y)) # Ensure in-place operators only promote to the same dtype as the left operand. elif ( _op.startswith("__i") and result_type(x.dtype, y.dtype) != x.dtype ): - assert_raises(TypeError, lambda: getattr(x, _op)(y), _op) + assert_raises(TypeError, lambda: getattr(x, _op)(y)) # Ensure only those dtypes that are required for every operator are allowed. - elif (dtypes == "all" + elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes + or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes) or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes @@ -207,7 +202,7 @@ def _array_vals(): ): getattr(x, _op)(y) else: - assert_raises(TypeError, lambda: getattr(x, _op)(y), (x, _op, y)) + assert_raises(TypeError, lambda: getattr(x, _op)(y)) unary_op_dtypes = { "__abs__": "numeric", @@ -226,7 +221,7 @@ def _array_vals(): # Only test for no error getattr(a, op)() else: - assert_raises(TypeError, lambda: getattr(a, op)(), _op) + assert_raises(TypeError, lambda: getattr(a, op)()) # Finally, matmul() must be tested separately, because it works a bit # different from the other operations. @@ -245,9 +240,9 @@ def _matmul_array_vals(): or type(s) == int and a.dtype in _integer_dtypes): # Type promotion is valid, but @ is not allowed on 0-D # inputs, so the error is a ValueError - assert_raises(ValueError, lambda: getattr(a, _op)(s), _op) + assert_raises(ValueError, lambda: getattr(a, _op)(s)) else: - assert_raises(TypeError, lambda: getattr(a, _op)(s), _op) + assert_raises(TypeError, lambda: getattr(a, _op)(s)) for x in _matmul_array_vals(): for y in _matmul_array_vals(): @@ -361,17 +356,20 @@ def test_allow_newaxis(): def test_disallow_flat_indexing_with_newaxis(): a = ones((3, 3, 3)) - assert_raises(IndexError, lambda: a[None, 0, 0]) + with pytest.raises(IndexError): + a[None, 0, 0] def test_disallow_mask_with_newaxis(): a = ones((3, 3, 3)) - assert_raises(IndexError, lambda: a[None, asarray(True)]) + with pytest.raises(IndexError): + a[None, asarray(True)] @pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)]) @pytest.mark.parametrize("index", ["string", False, True]) def test_error_on_invalid_index(shape, index): a = ones(shape) - assert_raises(IndexError, lambda: a[index]) + with pytest.raises(IndexError): + a[index] def test_mask_0d_array_without_errors(): a = ones(()) @@ -382,8 +380,10 @@ def test_mask_0d_array_without_errors(): ) def test_error_on_invalid_index_with_ellipsis(i): a = ones((3, 3, 3)) - assert_raises(IndexError, lambda: a[..., i]) - assert_raises(IndexError, lambda: a[i, ...]) + with pytest.raises(IndexError): + a[..., i] + with pytest.raises(IndexError): + a[i, ...] def test_array_keys_use_private_array(): """ @@ -400,7 +400,8 @@ def test_array_keys_use_private_array(): a = ones((0,), dtype=bool_) key = ones((0, 0), dtype=bool_) - assert_raises(IndexError, lambda: a[key]) + with pytest.raises(IndexError): + a[key] def test_array_namespace(): a = ones((3, 3)) @@ -421,16 +422,16 @@ def test_array_namespace(): assert a.__array_namespace__(api_version="2021.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2021.12" - assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) - assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) def test_iter(): - assert_raises(TypeError, lambda: iter(asarray(3))) + pytest.raises(TypeError, lambda: iter(asarray(3))) assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)] assert all_(isinstance(a, Array) for a in iter(ones(3))) assert all_(a.shape == () for a in iter(ones(3))) assert all_(a.dtype == float64 for a in iter(ones(3))) - assert_raises(TypeError, lambda: iter(ones((3, 3)))) + pytest.raises(TypeError, lambda: iter(ones((3, 3)))) @pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) def dlpack_2023_12(api_version): @@ -446,17 +447,17 @@ def dlpack_2023_12(api_version): exception = NotImplementedError if api_version >= '2023.12' else ValueError - assert_raises(exception, lambda: + pytest.raises(exception, lambda: a.__dlpack__(dl_device=CPU_DEVICE)) - assert_raises(exception, lambda: + pytest.raises(exception, lambda: a.__dlpack__(dl_device=None)) - assert_raises(exception, lambda: + pytest.raises(exception, lambda: a.__dlpack__(max_version=(1, 0))) - assert_raises(exception, lambda: + pytest.raises(exception, lambda: a.__dlpack__(max_version=None)) - assert_raises(exception, lambda: + pytest.raises(exception, lambda: a.__dlpack__(copy=False)) - assert_raises(exception, lambda: + pytest.raises(exception, lambda: a.__dlpack__(copy=True)) - assert_raises(exception, lambda: + pytest.raises(exception, lambda: a.__dlpack__(copy=None)) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 92c9c59..90994f3 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,6 +1,6 @@ from inspect import getfullargspec, getmodule -from .test_array_object import assert_raises +from numpy.testing import assert_raises from .. import asarray, _elementwise_functions from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift @@ -9,11 +9,6 @@ _boolean_dtypes, _floating_dtypes, _integer_dtypes, - int8, - int16, - int32, - int64, - uint64, ) from .._flags import set_array_api_strict_flags @@ -91,15 +86,6 @@ def nargs(func): "trunc": "real numeric", } -comparison_functions = [ - 'equal', - 'greater', - 'greater_equal', - 'less', - 'less_equal', - 'not_equal', -] - def test_missing_functions(): # Ensure the above dictionary is complete. import array_api_strict._elementwise_functions as mod @@ -129,20 +115,8 @@ def _array_vals(): func = getattr(_elementwise_functions, func_name) if nargs(func) == 2: for y in _array_vals(): - # Disallow dtypes that aren't type promotable - if (func_name not in comparison_functions and - (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] - or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] - or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes - or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes - or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes - or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes - or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes - or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes - )): - assert_raises(TypeError, lambda: func(x, y), (func_name, x, y)) if x.dtype not in dtypes or y.dtype not in dtypes: - assert_raises(TypeError, lambda: func(x, y), (func_name, x, y)) + assert_raises(TypeError, lambda: func(x, y)) else: if x.dtype not in dtypes: assert_raises(TypeError, lambda: func(x)) From 1c03aaaf8b8e93609b35aadee38b5e6ccdba494b Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 18 Jul 2024 15:49:09 -0600 Subject: [PATCH 117/252] Restore extra testing of elementwise function disallowed type promotions --- .../tests/test_elementwise_functions.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 90994f3..fa3405a 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -9,6 +9,11 @@ _boolean_dtypes, _floating_dtypes, _integer_dtypes, + int8, + int16, + int32, + int64, + uint64, ) from .._flags import set_array_api_strict_flags @@ -115,6 +120,17 @@ def _array_vals(): func = getattr(_elementwise_functions, func_name) if nargs(func) == 2: for y in _array_vals(): + # Disallow dtypes that aren't type promotable + if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64] + or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64] + or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes + or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes + or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes + or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes + or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes + or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes + ): + assert_raises(TypeError, lambda: func(x, y)) if x.dtype not in dtypes or y.dtype not in dtypes: assert_raises(TypeError, lambda: func(x, y)) else: From 600df5e02686a623219079c8b6a749a9398c55c8 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Wed, 4 Sep 2024 10:47:11 +0200 Subject: [PATCH 118/252] Add "multi device" support Having more than one device is useful during testing to allow you to find bugs related to how arrays on different devices are handled. --- array_api_strict/__init__.py | 3 ++ array_api_strict/_array_object.py | 28 +++++++++++----- array_api_strict/_creation_functions.py | 36 ++++++++++----------- array_api_strict/_typing.py | 4 +-- array_api_strict/tests/test_array_object.py | 2 +- 5 files changed, 44 insertions(+), 29 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 8dfa09f..ff43660 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -309,6 +309,9 @@ __all__ += ["all", "any"] +from ._array_object import Device +__all__ += ["Device"] + # Helper functions that are not part of the standard from ._flags import ( diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index d8ed018..72efd25 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -43,13 +43,17 @@ import numpy as np -# Placeholder object to represent the "cpu" device (the only device NumPy -# supports). -class _cpu_device: +class Device: + def __init__(self, device="CPU_DEVICE"): + self._device = device + def __repr__(self): - return "CPU_DEVICE" + return f"Device('{self._device}')" + + def __eq__(self, other): + return self._device == other._device -CPU_DEVICE = _cpu_device() +CPU_DEVICE = Device() _default = object() @@ -73,7 +77,7 @@ class Array: # Use a custom constructor instead of __init__, as manually initializing # this class is not supported API. @classmethod - def _new(cls, x, /): + def _new(cls, x, /, device=CPU_DEVICE): """ This is a private method for initializing the array API Array object. @@ -95,6 +99,9 @@ def _new(cls, x, /): ) obj._array = x obj._dtype = _dtype + if device is None: + device = CPU_DEVICE + obj._device = device return obj # Prevent Array() from working @@ -134,6 +141,8 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None will be present in other implementations. """ + if self._device != CPU_DEVICE: + raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") # copy keyword is new in 2.0.0; for older versions don't use it # retry without that keyword. if np.__version__[0] < '2': @@ -1154,8 +1163,11 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array: def to_device(self: Array, device: Device, /, stream: None = None) -> Array: if stream is not None: raise ValueError("The stream argument to to_device() is not supported") - if device == CPU_DEVICE: + if device == self._device: return self + elif isinstance(device, Device): + arr = np.asarray(self._array, copy=True) + return self.__class__._new(arr, device=device) raise ValueError(f"Unsupported device {device!r}") @property @@ -1169,7 +1181,7 @@ def dtype(self) -> Dtype: @property def device(self) -> Device: - return CPU_DEVICE + return self._device # Note: mT is new in array API spec (see matrix_transpose) @property diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 67ba67c..798e633 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -32,9 +32,9 @@ def _supports_buffer_protocol(obj): def _check_device(device): # _array_object imports in this file are inside the functions to avoid # circular imports - from ._array_object import CPU_DEVICE + from ._array_object import Device - if device not in [CPU_DEVICE, None]: + if device is not None and not isinstance(device, Device): raise ValueError(f"Unsupported device {device!r}") def asarray( @@ -79,7 +79,7 @@ def asarray( return Array._new(new_array) elif _supports_buffer_protocol(obj): # Buffer protocol will always support no-copy - return Array._new(np.array(obj, copy=copy, dtype=_np_dtype)) + return Array._new(np.array(obj, copy=copy, dtype=_np_dtype), device=device) else: # No-copy is unsupported for Python built-in types. raise ValueError("Unable to avoid copy while creating an array from given object.") @@ -89,13 +89,13 @@ def asarray( copy = False if isinstance(obj, Array): - return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype)) + return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype), device=device) if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)): # Give a better error message in this case. NumPy would convert this # to an object array. TODO: This won't handle large integers in lists. raise OverflowError("Integer out of bounds for array dtypes") res = np.array(obj, dtype=_np_dtype, copy=copy) - return Array._new(res) + return Array._new(res, device=device) def arange( @@ -119,7 +119,7 @@ def arange( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype)) + return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype), device=device) def empty( @@ -140,7 +140,7 @@ def empty( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.empty(shape, dtype=dtype)) + return Array._new(np.empty(shape, dtype=dtype), device=device) def empty_like( @@ -158,7 +158,7 @@ def empty_like( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.empty_like(x._array, dtype=dtype)) + return Array._new(np.empty_like(x._array, dtype=dtype), device=device) def eye( @@ -182,7 +182,7 @@ def eye( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype)) + return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype), device=device) _default = object() @@ -237,7 +237,7 @@ def full( # This will happen if the fill value is not something that NumPy # coerces to one of the acceptable dtypes. raise TypeError("Invalid input to full") - return Array._new(res) + return Array._new(res, device=device) def full_like( @@ -265,7 +265,7 @@ def full_like( # This will happen if the fill value is not something that NumPy # coerces to one of the acceptable dtypes. raise TypeError("Invalid input to full_like") - return Array._new(res) + return Array._new(res, device=device) def linspace( @@ -290,7 +290,7 @@ def linspace( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)) + return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint), device=device) def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: @@ -308,7 +308,7 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: raise ValueError("meshgrid inputs must all have the same dtype") return [ - Array._new(array) + Array._new(array, device=device) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing) ] @@ -331,7 +331,7 @@ def ones( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.ones(shape, dtype=dtype)) + return Array._new(np.ones(shape, dtype=dtype), device=device) def ones_like( @@ -349,7 +349,7 @@ def ones_like( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.ones_like(x._array, dtype=dtype)) + return Array._new(np.ones_like(x._array, dtype=dtype), device=device) def tril(x: Array, /, *, k: int = 0) -> Array: @@ -377,7 +377,7 @@ def triu(x: Array, /, *, k: int = 0) -> Array: if x.ndim < 2: # Note: Unlike np.triu, x must be at least 2-D raise ValueError("x must be at least 2-dimensional for triu") - return Array._new(np.triu(x._array, k=k)) + return Array._new(np.triu(x._array, k=k), device=device) def zeros( @@ -398,7 +398,7 @@ def zeros( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.zeros(shape, dtype=dtype)) + return Array._new(np.zeros(shape, dtype=dtype), device=device) def zeros_like( @@ -416,4 +416,4 @@ def zeros_like( if dtype is not None: dtype = dtype._np_dtype - return Array._new(np.zeros_like(x._array, dtype=dtype)) + return Array._new(np.zeros_like(x._array, dtype=dtype), device=device) diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index eb1b834..05a479c 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -27,7 +27,7 @@ Protocol, ) -from ._array_object import Array, _cpu_device +from ._array_object import Array, _device from ._dtypes import _DType _T_co = TypeVar("_T_co", covariant=True) @@ -37,7 +37,7 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... -Device = _cpu_device +Device = _device Dtype = _DType diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index b0d4868..5146bba 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -319,7 +319,7 @@ def test_python_scalar_construtors(): def test_device_property(): a = ones((3, 4)) assert a.device == CPU_DEVICE - assert a.device != 'cpu' + assert not isinstance(a.device, str) assert all(equal(a.to_device(CPU_DEVICE), a)) assert_raises(ValueError, lambda: a.to_device('cpu')) From 36d15bbe24eeb0c597bd98dd248b1d9adbe21c40 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 13 Sep 2024 15:18:12 -0600 Subject: [PATCH 119/252] Fix __all__ not getting updated with reset_array_api_strict_flags() --- array_api_strict/_flags.py | 4 +++- array_api_strict/tests/test_flags.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 62acddf..46c0786 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -262,7 +262,9 @@ def reset_array_api_strict_flags(): BOOLEAN_INDEXING = True DATA_DEPENDENT_SHAPES = True ENABLED_EXTENSIONS = default_extensions - + array_api_strict.__all__[:] = sorted(set(ENABLED_EXTENSIONS) | + set(array_api_strict.__all__) - + set(default_extensions)) class ArrayAPIStrictFlags: """ diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 86ad8e2..76ca596 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -371,6 +371,15 @@ def test_disabled_extensions(): assert 'linalg' not in ns assert 'fft' not in ns + reset_array_api_strict_flags() + assert 'linalg' in xp.__all__ + assert 'fft' in xp.__all__ + xp.linalg # No error + xp.fft # No error + ns = {} + exec('from array_api_strict import *', ns) + assert 'linalg' in ns + assert 'fft' in ns def test_environment_variables(): # Test that the environment variables work as expected From b3d5214e453d12780897048b9732da1ca0c57b39 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 13 Sep 2024 15:20:47 -0600 Subject: [PATCH 120/252] Make 2023.12 the default version SciPy and others have been using it and haven't found any issues. Test suite support is still not 100% but is pretty strong at this point. This also splits some of the tests to avoid setting different versions of flags within the same test. --- README.md | 4 - array_api_strict/_flags.py | 8 +- array_api_strict/tests/test_array_object.py | 9 +- .../tests/test_elementwise_functions.py | 3 +- array_api_strict/tests/test_flags.py | 96 +++++++++---------- array_api_strict/tests/test_linalg.py | 27 +++--- .../tests/test_statistical_functions.py | 20 +++- docs/index.md | 16 +--- 8 files changed, 87 insertions(+), 96 deletions(-) diff --git a/README.md b/README.md index 8172237..0d52fec 100644 --- a/README.md +++ b/README.md @@ -15,8 +15,4 @@ libraries. Consuming library code should use the support the array API. Rather, it is intended to be used in the test suites of consuming libraries to test their array API usage. -array-api-strict currently supports the 2022.12 version of the standard. -2023.12 support is planned and is tracked by [this -issue](https://github.com/data-apis/array-api-strict/issues/25). - See the documentation for more details https://data-apis.org/array-api-strict/ diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 46c0786..c393ad9 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -24,7 +24,7 @@ "2023.12", ) -API_VERSION = default_version = "2022.12" +API_VERSION = default_version = "2023.12" BOOLEAN_INDEXING = True @@ -76,10 +76,6 @@ def set_array_api_strict_flags( Note that 2021.12 is supported, but currently gives the same thing as 2022.12 (except that the fft extension will be disabled). - 2023.12 support is experimental. Some features in 2023.12 may still be - missing, and it hasn't been fully tested. A future version of - array-api-strict will change the default version to 2023.12. - boolean_indexing : bool, optional Whether indexing by a boolean array is supported. This flag is enabled by default. Note that although boolean array indexing does result in @@ -142,8 +138,6 @@ def set_array_api_strict_flags( raise ValueError(f"Unsupported standard version {api_version!r}") if api_version == "2021.12": warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2) - if api_version == "2023.12": - warnings.warn("The 2023.12 version of the array API specification is still preliminary. Some functions are not yet implemented, and it has not been fully tested.", stacklevel=2) API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index b0d4868..dad6696 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -406,16 +406,15 @@ def test_array_keys_use_private_array(): def test_array_namespace(): a = ones((3, 3)) assert a.__array_namespace__() == array_api_strict - assert array_api_strict.__array_api_version__ == "2022.12" + assert array_api_strict.__array_api_version__ == "2023.12" assert a.__array_namespace__(api_version=None) is array_api_strict - assert array_api_strict.__array_api_version__ == "2022.12" + assert array_api_strict.__array_api_version__ == "2023.12" assert a.__array_namespace__(api_version="2022.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2022.12" - with pytest.warns(UserWarning): - assert a.__array_namespace__(api_version="2023.12") is array_api_strict + assert a.__array_namespace__(api_version="2023.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2023.12" with pytest.warns(UserWarning): @@ -435,7 +434,7 @@ def test_iter(): @pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) def dlpack_2023_12(api_version): - if api_version != '2022.12': + if api_version == '2021.12': with pytest.warns(UserWarning): set_array_api_strict_flags(api_version=api_version) else: diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index fa3405a..8f3ce7a 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -111,8 +111,7 @@ def _array_vals(): yield asarray(1.0, dtype=d) # Use the latest version of the standard so all functions are included - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version="2023.12") + set_array_api_strict_flags(api_version="2023.12") for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 76ca596..2603f35 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -18,21 +18,38 @@ import pytest -def test_flags(): - # Test defaults +def test_flag_defaults(): flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2022.12', + 'api_version': '2023.12', + 'boolean_indexing': True, + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } + + +def test_reset_flags(): + with pytest.warns(UserWarning): + set_array_api_strict_flags( + api_version='2021.12', + boolean_indexing=False, + data_dependent_shapes=False, + enabled_extensions=()) + reset_array_api_strict_flags() + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2023.12', 'boolean_indexing': True, 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft'), } - # Test setting flags + +def test_setting_flags(): set_array_api_strict_flags(data_dependent_shapes=False) flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2022.12', + 'api_version': '2023.12', 'boolean_indexing': True, 'data_dependent_shapes': False, 'enabled_extensions': ('linalg', 'fft'), @@ -40,11 +57,13 @@ def test_flags(): set_array_api_strict_flags(enabled_extensions=('fft',)) flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2022.12', + 'api_version': '2023.12', 'boolean_indexing': True, 'data_dependent_shapes': False, 'enabled_extensions': ('fft',), } + +def test_flags_api_version_2021_12(): # Make sure setting the version to 2021.12 disables fft and issues a # warning. with pytest.warns(UserWarning) as record: @@ -55,27 +74,23 @@ def test_flags(): assert flags == { 'api_version': '2021.12', 'boolean_indexing': True, - 'data_dependent_shapes': False, - 'enabled_extensions': (), + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg',), } - reset_array_api_strict_flags() - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2021.12') +def test_flags_api_version_2022_12(): + set_array_api_strict_flags(api_version='2022.12') flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2021.12', + 'api_version': '2022.12', 'boolean_indexing': True, 'data_dependent_shapes': True, - 'enabled_extensions': ('linalg',), + 'enabled_extensions': ('linalg', 'fft'), } - reset_array_api_strict_flags() - # 2023.12 should issue a warning - with pytest.warns(UserWarning) as record: - set_array_api_strict_flags(api_version='2023.12') - assert len(record) == 1 - assert '2023.12' in str(record[0].message) + +def test_flags_api_version_2023_12(): + set_array_api_strict_flags(api_version='2023.12') flags = get_array_api_strict_flags() assert flags == { 'api_version': '2023.12', @@ -84,6 +99,7 @@ def test_flags(): 'enabled_extensions': ('linalg', 'fft'), } +def test_setting_flags_invalid(): # Test setting flags with invalid values pytest.raises(ValueError, lambda: set_array_api_strict_flags(api_version='2020.12')) @@ -94,35 +110,15 @@ def test_flags(): api_version='2021.12', enabled_extensions=('linalg', 'fft'))) - # Test resetting flags - with pytest.warns(UserWarning): - set_array_api_strict_flags( - api_version='2021.12', - boolean_indexing=False, - data_dependent_shapes=False, - enabled_extensions=()) - reset_array_api_strict_flags() - flags = get_array_api_strict_flags() - assert flags == { - 'api_version': '2022.12', - 'boolean_indexing': True, - 'data_dependent_shapes': True, - 'enabled_extensions': ('linalg', 'fft'), - } - def test_api_version(): # Test defaults - assert xp.__array_api_version__ == '2022.12' + assert xp.__array_api_version__ == '2023.12' # Test setting the version - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2021.12') - assert xp.__array_api_version__ == '2021.12' + set_array_api_strict_flags(api_version='2022.12') + assert xp.__array_api_version__ == '2022.12' def test_data_dependent_shapes(): - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2023.12') # to enable repeat() - a = asarray([0, 0, 1, 2, 2]) mask = asarray([True, False, True, False, True]) repeats = asarray([1, 1, 2, 2, 2]) @@ -275,12 +271,16 @@ def test_fft(func_name): def test_api_version_2023_12(func_name): func = api_version_2023_12_examples[func_name] - # By default, these functions should error + # By default, these functions should not error + func() + + # In 2022.12, these functions should error + set_array_api_strict_flags(api_version='2022.12') pytest.raises(RuntimeError, func) - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2023.12') - func() + # Test the behavior gets updated properly + set_array_api_strict_flags(api_version='2023.12') + func() set_array_api_strict_flags(api_version='2022.12') pytest.raises(RuntimeError, func) @@ -387,9 +387,9 @@ def test_environment_variables(): # ARRAY_API_STRICT_API_VERSION ('''\ import array_api_strict as xp -assert xp.__array_api_version__ == '2022.12' +assert xp.__array_api_version__ == '2023.12' -assert xp.get_array_api_strict_flags()['api_version'] == '2022.12' +assert xp.get_array_api_strict_flags()['api_version'] == '2023.12' ''', {}), *[ diff --git a/array_api_strict/tests/test_linalg.py b/array_api_strict/tests/test_linalg.py index 5e6cda2..04023bc 100644 --- a/array_api_strict/tests/test_linalg.py +++ b/array_api_strict/tests/test_linalg.py @@ -8,14 +8,17 @@ # Technically this is linear_algebra, not linalg, but it's simpler to keep # both of these tests together -def test_vecdot_2023_12(): - # Test the axis < 0 restriction for 2023.12, and also the 2022.12 axis >= - # 0 behavior (which is primarily kept for backwards compatibility). + + +# Test the axis < 0 restriction for 2023.12, and also the 2022.12 axis >= +# 0 behavior (which is primarily kept for backwards compatibility). +def test_vecdot_2022_12(): + # 2022.12 behavior, which is to apply axis >= 0 after broadcasting + set_array_api_strict_flags(api_version='2022.12') a = xp.ones((2, 3, 4, 5)) b = xp.ones(( 3, 4, 1)) - # 2022.12 behavior, which is to apply axis >= 0 after broadcasting pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0)) assert xp.linalg.vecdot(a, b, axis=1).shape == (2, 4, 5) assert xp.linalg.vecdot(a, b, axis=2).shape == (2, 3, 5) @@ -34,10 +37,13 @@ def test_vecdot_2023_12(): assert xp.linalg.vecdot(a, b, axis=-2).shape == (2, 3, 5) assert xp.linalg.vecdot(a, b, axis=-3).shape == (2, 4, 5) +def test_vecdot_2023_12(): # 2023.12 behavior, which is to only allow axis < 0 and axis >= # min(x1.ndim, x2.ndim), which is unambiguous - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2023.12') + set_array_api_strict_flags(api_version='2023.12') + + a = xp.ones((2, 3, 4, 5)) + b = xp.ones(( 3, 4, 1)) pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0)) pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=1)) @@ -56,7 +62,7 @@ def test_cross(api_version): # This test tests everything that should be the same across all supported # API versions. - if api_version != '2022.12': + if api_version == '2021.12': with pytest.warns(UserWarning): set_array_api_strict_flags(api_version=api_version) else: @@ -88,7 +94,7 @@ def test_cross_2022_12(api_version): # backwards compatibility. Note that unlike vecdot, array_api_strict # cross() never implemented the "after broadcasting" axis behavior, but # just reused NumPy cross(), which applies axes before broadcasting. - if api_version != '2022.12': + if api_version == '2021.12': with pytest.warns(UserWarning): set_array_api_strict_flags(api_version=api_version) else: @@ -104,11 +110,6 @@ def test_cross_2022_12(api_version): assert xp.linalg.cross(a, b, axis=0).shape == (3, 2, 4, 5) def test_cross_2023_12(): - # 2023.12 behavior, which is to only allow axis < 0 and axis >= - # min(x1.ndim, x2.ndim), which is unambiguous - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2023.12') - a = xp.ones((3, 2, 4, 5)) b = xp.ones((3, 2, 4, 1)) pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=0)) diff --git a/array_api_strict/tests/test_statistical_functions.py b/array_api_strict/tests/test_statistical_functions.py index 61e848c..7f2a457 100644 --- a/array_api_strict/tests/test_statistical_functions.py +++ b/array_api_strict/tests/test_statistical_functions.py @@ -4,10 +4,12 @@ import array_api_strict as xp +# sum, prod, and trace were changed in 2023.12 to not upcast floating-point dtypes +# with dtype=None @pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace']) -def test_sum_prod_trace_2023_12(func_name): - # sum, prod, and trace were changed in 2023.12 to not upcast floating-point dtypes - # with dtype=None +def test_sum_prod_trace_2022_12(func_name): + set_array_api_strict_flags(api_version='2022.12') + if func_name == 'trace': func = getattr(xp.linalg, func_name) else: @@ -21,8 +23,16 @@ def test_sum_prod_trace_2023_12(func_name): assert func(a_complex).dtype == xp.complex128 assert func(a_int).dtype == xp.int64 - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2023.12') +@pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace']) +def test_sum_prod_trace_2023_12(func_name): + a_real = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.float32) + a_complex = xp.asarray([[1., 2.], [3., 4.]], dtype=xp.complex64) + a_int = xp.asarray([[1, 2], [3, 4]], dtype=xp.int32) + + if func_name == 'trace': + func = getattr(xp.linalg, func_name) + else: + func = getattr(xp, func_name) assert func(a_real).dtype == xp.float32 assert func(a_complex).dtype == xp.complex64 diff --git a/docs/index.md b/docs/index.md index a14fbcb..12aadbb 100644 --- a/docs/index.md +++ b/docs/index.md @@ -16,10 +16,10 @@ support the array API. Rather, it is intended to be used in the test suites of consuming libraries to test their array API usage. array-api-strict currently supports the -[2022.12](https://data-apis.org/array-api/latest/changelog.html#v2022-12) -version of the standard. Experimental -[2023.12](https://data-apis.org/array-api/latest/changelog.html#v2023-12) -support is implemented, [but must be enabled with a +[2023.12](https://data-apis.org/array-api/latest/changelog.html#v2022-12) +version of the standard. +[2022.12](https://data-apis.org/array-api/latest/changelog.html#v2023-12) +support is also implemented, [and can be enabled with a flag](array-api-strict-flags). ## Install @@ -176,14 +176,6 @@ issue, but this hasn't necessarily been tested thoroughly. this deviation may be tested with type checking. This [behavior may improve in the future](https://github.com/data-apis/array-api-strict/issues/6). -5. array-api-strict currently uses the 2022.12 version of the array API - standard by default. Support for 2023.12 is implemented but is still - experimental and not fully tested. It can be enabled with - {func}`array_api_strict.set_array_api_strict_flags(api_version='2023.12') - ` or by setting the - environment variable {envvar}`ARRAY_API_STRICT_API_VERSION=2023.12 - `. - (numpy.array_api)= ## Relationship to `numpy.array_api` From 2aae491a421bc8aa74b35d37d3006cf20d479305 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 13 Sep 2024 15:23:54 -0600 Subject: [PATCH 121/252] Remove unused import --- array_api_strict/tests/test_elementwise_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 8f3ce7a..870361e 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -17,7 +17,6 @@ ) from .._flags import set_array_api_strict_flags -import pytest def nargs(func): return len(getfullargspec(func).args) From 325b9d0e87c5cd137910d2b3726ba204d5a65456 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Fri, 27 Sep 2024 10:38:02 +0200 Subject: [PATCH 122/252] Loop device through elementwise functions --- array_api_strict/_array_object.py | 73 +++++--- array_api_strict/_elementwise_functions.py | 196 ++++++++++++++------- array_api_strict/_info.py | 4 +- 3 files changed, 185 insertions(+), 88 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 72efd25..be19185 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -45,10 +45,12 @@ class Device: def __init__(self, device="CPU_DEVICE"): + if device not in ("CPU_DEVICE", "device1", "device2"): + raise ValueError(f"The device '{device}' is not a valid choice.") self._device = device def __repr__(self): - return f"Device('{self._device}')" + return f"array_api_strict.Device('{self._device}')" def __eq__(self, other): return self._device == other._device @@ -77,7 +79,7 @@ class Array: # Use a custom constructor instead of __init__, as manually initializing # this class is not supported API. @classmethod - def _new(cls, x, /, device=CPU_DEVICE): + def _new(cls, x, /, device=None): """ This is a private method for initializing the array API Array object. @@ -123,7 +125,11 @@ def __repr__(self: Array, /) -> str: """ Performs the operation __repr__. """ - suffix = f", dtype={self.dtype})" + suffix = f", dtype={self.dtype}" + if self.device != CPU_DEVICE: + suffix += f", device={self.device})" + else: + suffix += ")" if 0 in self.shape: prefix = "empty(" mid = str(self.shape) @@ -202,6 +208,15 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor return other + def _check_device(self, other): + """Check that other is on a device compatible with the current array""" + if isinstance(other, (int, complex, float, bool)): + return other + elif isinstance(other, Array): + if self.device != other.device: + raise RuntimeError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") + return other + # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar): """ @@ -477,23 +492,25 @@ def __add__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __add__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__add__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__add__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __and__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__and__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__and__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __array_namespace__( self: Array, /, *, api_version: Optional[str] = None @@ -577,6 +594,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ Performs the operation __eq__. """ + other = self._check_device(other) # Even though "all" dtypes are allowed, we still require them to be # promotable with each other. other = self._check_allowed_dtypes(other, "all", "__eq__") @@ -584,7 +602,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: return other self, other = self._normalize_two_args(self, other) res = self._array.__eq__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __float__(self: Array, /) -> float: """ @@ -602,23 +620,25 @@ def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __floordiv__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__floordiv__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ge__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__ge__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__ge__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __getitem__( self: Array, @@ -634,6 +654,7 @@ def __getitem__( """ Performs the operation __getitem__. """ + # XXX Does key have to be on the same device? Is there an exception for CPU_DEVICE? # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index self._validate_index(key) @@ -641,12 +662,13 @@ def __getitem__( # Indexing self._array with array_api_strict arrays can be erroneous key = key._array res = self._array.__getitem__(key) - return self._new(res) + return self._new(res, device=self.device) def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __gt__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__gt__") if other is NotImplemented: return other @@ -680,7 +702,7 @@ def __invert__(self: Array, /) -> Array: if self.dtype not in _integer_or_boolean_dtypes: raise TypeError("Only integer or boolean dtypes are allowed in __invert__") res = self._array.__invert__() - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __iter__(self: Array, /): """ @@ -695,85 +717,92 @@ def __iter__(self: Array, /): # define __iter__, but it doesn't disallow it. The default Python # behavior is to implement iter as a[0], a[1], ... when __getitem__ is # implemented, which implies iteration on 1-D arrays. - return (Array._new(i) for i in self._array) + return (Array._new(i, device=self.device) for i in self._array) def __le__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __le__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__le__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__le__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __lshift__(self: Array, other: Union[int, Array], /) -> Array: """ Performs the operation __lshift__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "integer", "__lshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__lshift__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __lt__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__lt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__lt__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __matmul__(self: Array, other: Array, /) -> Array: """ Performs the operation __matmul__. """ + other = self._check_device(other) # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._check_allowed_dtypes(other, "numeric", "__matmul__") if other is NotImplemented: return other res = self._array.__matmul__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __mod__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __mod__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__mod__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__mod__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __mul__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __mul__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__mul__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__mul__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ Performs the operation __ne__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "all", "__ne__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__ne__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __neg__(self: Array, /) -> Array: """ @@ -782,18 +811,19 @@ def __neg__(self: Array, /) -> Array: if self.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in __neg__") res = self._array.__neg__() - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __or__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __or__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__or__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__or__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __pos__(self: Array, /) -> Array: """ @@ -802,7 +832,7 @@ def __pos__(self: Array, /) -> Array: if self.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in __pos__") res = self._array.__pos__() - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: """ @@ -810,6 +840,7 @@ def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: """ from ._elementwise_functions import pow + other = self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__pow__") if other is NotImplemented: return other diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index b39bd86..4035841 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -28,7 +28,7 @@ def abs(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in abs") - return Array._new(np.abs(x._array)) + return Array._new(np.abs(x._array), device=x.device) # Note: the function name is different here @@ -40,7 +40,7 @@ def acos(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in acos") - return Array._new(np.arccos(x._array)) + return Array._new(np.arccos(x._array), device=x.device) # Note: the function name is different here @@ -52,7 +52,7 @@ def acosh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in acosh") - return Array._new(np.arccosh(x._array)) + return Array._new(np.arccosh(x._array), device=x.device) def add(x1: Array, x2: Array, /) -> Array: @@ -61,12 +61,15 @@ def add(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in add") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.add(x1._array, x2._array)) + return Array._new(np.add(x1._array, x2._array), device=x1.device) # Note: the function name is different here @@ -78,7 +81,7 @@ def asin(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in asin") - return Array._new(np.arcsin(x._array)) + return Array._new(np.arcsin(x._array), device=x.device) # Note: the function name is different here @@ -90,7 +93,7 @@ def asinh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in asinh") - return Array._new(np.arcsinh(x._array)) + return Array._new(np.arcsinh(x._array), device=x.device) # Note: the function name is different here @@ -102,7 +105,7 @@ def atan(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in atan") - return Array._new(np.arctan(x._array)) + return Array._new(np.arctan(x._array), device=x.device) # Note: the function name is different here @@ -112,12 +115,14 @@ def atan2(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in atan2") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.arctan2(x1._array, x2._array)) + return Array._new(np.arctan2(x1._array, x2._array), device=x1.device) # Note: the function name is different here @@ -129,7 +134,7 @@ def atanh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in atanh") - return Array._new(np.arctanh(x._array)) + return Array._new(np.arctanh(x._array), device=x.device) def bitwise_and(x1: Array, x2: Array, /) -> Array: @@ -138,6 +143,9 @@ def bitwise_and(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if ( x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes @@ -146,7 +154,7 @@ def bitwise_and(x1: Array, x2: Array, /) -> Array: # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_and(x1._array, x2._array)) + return Array._new(np.bitwise_and(x1._array, x2._array), device=x1.device) # Note: the function name is different here @@ -156,6 +164,9 @@ def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError("Only integer dtypes are allowed in bitwise_left_shift") # Call result type here just to raise on disallowed type combinations @@ -164,7 +175,7 @@ def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: # Note: bitwise_left_shift is only defined for x2 nonnegative. if np.any(x2._array < 0): raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") - return Array._new(np.left_shift(x1._array, x2._array)) + return Array._new(np.left_shift(x1._array, x2._array), device=x1.device) # Note: the function name is different here @@ -176,7 +187,7 @@ def bitwise_invert(x: Array, /) -> Array: """ if x.dtype not in _integer_or_boolean_dtypes: raise TypeError("Only integer or boolean dtypes are allowed in bitwise_invert") - return Array._new(np.invert(x._array)) + return Array._new(np.invert(x._array), device=x.device) def bitwise_or(x1: Array, x2: Array, /) -> Array: @@ -185,6 +196,9 @@ def bitwise_or(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if ( x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes @@ -193,7 +207,7 @@ def bitwise_or(x1: Array, x2: Array, /) -> Array: # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_or(x1._array, x2._array)) + return Array._new(np.bitwise_or(x1._array, x2._array), device=x1.device) # Note: the function name is different here @@ -203,6 +217,9 @@ def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError("Only integer dtypes are allowed in bitwise_right_shift") # Call result type here just to raise on disallowed type combinations @@ -211,7 +228,7 @@ def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: # Note: bitwise_right_shift is only defined for x2 nonnegative. if np.any(x2._array < 0): raise ValueError("bitwise_right_shift(x1, x2) is only defined for x2 >= 0") - return Array._new(np.right_shift(x1._array, x2._array)) + return Array._new(np.right_shift(x1._array, x2._array), device=x1.device) def bitwise_xor(x1: Array, x2: Array, /) -> Array: @@ -220,6 +237,9 @@ def bitwise_xor(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if ( x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes @@ -228,7 +248,7 @@ def bitwise_xor(x1: Array, x2: Array, /) -> Array: # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_xor(x1._array, x2._array)) + return Array._new(np.bitwise_xor(x1._array, x2._array), device=x1.device) def ceil(x: Array, /) -> Array: @@ -242,7 +262,7 @@ def ceil(x: Array, /) -> Array: if x.dtype in _integer_dtypes: # Note: The return dtype of ceil is the same as the input return x - return Array._new(np.ceil(x._array)) + return Array._new(np.ceil(x._array), device=x.device) # WARNING: This function is not yet tested by the array-api-tests test suite. @@ -259,6 +279,11 @@ def clip( See its docstring for more information. """ + if isinstance(min, Array) and x.device != min.device: + raise RuntimeError(f"Arrays from two different devices ({x.device} and {min.device}) can not be combined.") + if isinstance(max, Array) and x.device != max.device: + raise RuntimeError(f"Arrays from two different devices ({x.device} and {max.device}) can not be combined.") + if (x.dtype not in _real_numeric_dtypes or isinstance(min, Array) and min.dtype not in _real_numeric_dtypes or isinstance(max, Array) and max.dtype not in _real_numeric_dtypes): @@ -307,7 +332,7 @@ def clip( # TODO: I'm not completely sure this always gives the correct thing # for integer dtypes. See https://github.com/numpy/numpy/issues/24976 result = result.astype(x.dtype._np_dtype) - return Array._new(result) + return Array._new(result, device=x.device) def conj(x: Array, /) -> Array: """ @@ -317,7 +342,7 @@ def conj(x: Array, /) -> Array: """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in conj") - return Array._new(np.conj(x)) + return Array._new(np.conj(x), device=x.device) @requires_api_version('2023.12') def copysign(x1: Array, x2: Array, /) -> Array: @@ -326,12 +351,15 @@ def copysign(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in copysign") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.copysign(x1._array, x2._array)) + return Array._new(np.copysign(x1._array, x2._array), device=x1.device) def cos(x: Array, /) -> Array: """ @@ -341,7 +369,7 @@ def cos(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in cos") - return Array._new(np.cos(x._array)) + return Array._new(np.cos(x._array), device=x.device) def cosh(x: Array, /) -> Array: @@ -352,7 +380,7 @@ def cosh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in cosh") - return Array._new(np.cosh(x._array)) + return Array._new(np.cosh(x._array), device=x.device) def divide(x1: Array, x2: Array, /) -> Array: @@ -361,12 +389,14 @@ def divide(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in divide") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.divide(x1._array, x2._array)) + return Array._new(np.divide(x1._array, x2._array), device=x1.device) def equal(x1: Array, x2: Array, /) -> Array: @@ -375,10 +405,12 @@ def equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.equal(x1._array, x2._array)) + return Array._new(np.equal(x1._array, x2._array), device=x1.device) def exp(x: Array, /) -> Array: @@ -389,7 +421,7 @@ def exp(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in exp") - return Array._new(np.exp(x._array)) + return Array._new(np.exp(x._array), device=x.device) def expm1(x: Array, /) -> Array: @@ -400,7 +432,7 @@ def expm1(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in expm1") - return Array._new(np.expm1(x._array)) + return Array._new(np.expm1(x._array), device=x.device) def floor(x: Array, /) -> Array: @@ -414,7 +446,7 @@ def floor(x: Array, /) -> Array: if x.dtype in _integer_dtypes: # Note: The return dtype of floor is the same as the input return x - return Array._new(np.floor(x._array)) + return Array._new(np.floor(x._array), device=x.device) def floor_divide(x1: Array, x2: Array, /) -> Array: @@ -423,12 +455,14 @@ def floor_divide(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in floor_divide") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.floor_divide(x1._array, x2._array)) + return Array._new(np.floor_divide(x1._array, x2._array), device=x1.device) def greater(x1: Array, x2: Array, /) -> Array: @@ -437,12 +471,14 @@ def greater(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in greater") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.greater(x1._array, x2._array)) + return Array._new(np.greater(x1._array, x2._array), device=x1.device) def greater_equal(x1: Array, x2: Array, /) -> Array: @@ -451,12 +487,14 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in greater_equal") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.greater_equal(x1._array, x2._array)) + return Array._new(np.greater_equal(x1._array, x2._array), device=x1.device) @requires_api_version('2023.12') def hypot(x1: Array, x2: Array, /) -> Array: @@ -465,12 +503,14 @@ def hypot(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in hypot") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.hypot(x1._array, x2._array)) + return Array._new(np.hypot(x1._array, x2._array), device=x1.device) def imag(x: Array, /) -> Array: """ @@ -480,7 +520,7 @@ def imag(x: Array, /) -> Array: """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in imag") - return Array._new(np.imag(x)) + return Array._new(np.imag(x), device=x.device) def isfinite(x: Array, /) -> Array: @@ -491,7 +531,7 @@ def isfinite(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in isfinite") - return Array._new(np.isfinite(x._array)) + return Array._new(np.isfinite(x._array), device=x.device) def isinf(x: Array, /) -> Array: @@ -502,7 +542,7 @@ def isinf(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in isinf") - return Array._new(np.isinf(x._array)) + return Array._new(np.isinf(x._array), device=x.device) def isnan(x: Array, /) -> Array: @@ -513,7 +553,7 @@ def isnan(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in isnan") - return Array._new(np.isnan(x._array)) + return Array._new(np.isnan(x._array), device=x.device) def less(x1: Array, x2: Array, /) -> Array: @@ -522,12 +562,14 @@ def less(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in less") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.less(x1._array, x2._array)) + return Array._new(np.less(x1._array, x2._array), device=x1.device) def less_equal(x1: Array, x2: Array, /) -> Array: @@ -536,12 +578,14 @@ def less_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in less_equal") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.less_equal(x1._array, x2._array)) + return Array._new(np.less_equal(x1._array, x2._array), device=x1.device) def log(x: Array, /) -> Array: @@ -552,7 +596,7 @@ def log(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in log") - return Array._new(np.log(x._array)) + return Array._new(np.log(x._array), device=x.device) def log1p(x: Array, /) -> Array: @@ -563,7 +607,7 @@ def log1p(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in log1p") - return Array._new(np.log1p(x._array)) + return Array._new(np.log1p(x._array), device=x.device) def log2(x: Array, /) -> Array: @@ -574,7 +618,7 @@ def log2(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in log2") - return Array._new(np.log2(x._array)) + return Array._new(np.log2(x._array), device=x.device) def log10(x: Array, /) -> Array: @@ -585,7 +629,7 @@ def log10(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in log10") - return Array._new(np.log10(x._array)) + return Array._new(np.log10(x._array), device=x.device) def logaddexp(x1: Array, x2: Array) -> Array: @@ -594,12 +638,14 @@ def logaddexp(x1: Array, x2: Array) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in logaddexp") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logaddexp(x1._array, x2._array)) + return Array._new(np.logaddexp(x1._array, x2._array), device=x1.device) def logical_and(x1: Array, x2: Array, /) -> Array: @@ -608,12 +654,14 @@ def logical_and(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_and") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_and(x1._array, x2._array)) + return Array._new(np.logical_and(x1._array, x2._array), device=x1.device) def logical_not(x: Array, /) -> Array: @@ -624,7 +672,7 @@ def logical_not(x: Array, /) -> Array: """ if x.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_not") - return Array._new(np.logical_not(x._array)) + return Array._new(np.logical_not(x._array), device=x.device) def logical_or(x1: Array, x2: Array, /) -> Array: @@ -633,12 +681,14 @@ def logical_or(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_or") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_or(x1._array, x2._array)) + return Array._new(np.logical_or(x1._array, x2._array), device=x1.device) def logical_xor(x1: Array, x2: Array, /) -> Array: @@ -647,12 +697,14 @@ def logical_xor(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_xor") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_xor(x1._array, x2._array)) + return Array._new(np.logical_xor(x1._array, x2._array), device=x1.device) @requires_api_version('2023.12') def maximum(x1: Array, x2: Array, /) -> Array: @@ -661,6 +713,8 @@ def maximum(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in maximum") # Call result type here just to raise on disallowed type combinations @@ -668,7 +722,7 @@ def maximum(x1: Array, x2: Array, /) -> Array: x1, x2 = Array._normalize_two_args(x1, x2) # TODO: maximum(-0., 0.) is unspecified. Should we issue a warning/error # in that case? - return Array._new(np.maximum(x1._array, x2._array)) + return Array._new(np.maximum(x1._array, x2._array), device=x1.device) @requires_api_version('2023.12') def minimum(x1: Array, x2: Array, /) -> Array: @@ -677,12 +731,14 @@ def minimum(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in minimum") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.minimum(x1._array, x2._array)) + return Array._new(np.minimum(x1._array, x2._array), device=x1.device) def multiply(x1: Array, x2: Array, /) -> Array: """ @@ -690,12 +746,14 @@ def multiply(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in multiply") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.multiply(x1._array, x2._array)) + return Array._new(np.multiply(x1._array, x2._array), device=x1.device) def negative(x: Array, /) -> Array: @@ -706,7 +764,7 @@ def negative(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in negative") - return Array._new(np.negative(x._array)) + return Array._new(np.negative(x._array), device=x.device) def not_equal(x1: Array, x2: Array, /) -> Array: @@ -715,10 +773,12 @@ def not_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.not_equal(x1._array, x2._array)) + return Array._new(np.not_equal(x1._array, x2._array), device=x1.device) def positive(x: Array, /) -> Array: @@ -729,7 +789,7 @@ def positive(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in positive") - return Array._new(np.positive(x._array)) + return Array._new(np.positive(x._array), device=x.device) # Note: the function name is different here @@ -739,12 +799,14 @@ def pow(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in pow") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.power(x1._array, x2._array)) + return Array._new(np.power(x1._array, x2._array), device=x1.device) def real(x: Array, /) -> Array: @@ -755,7 +817,7 @@ def real(x: Array, /) -> Array: """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in real") - return Array._new(np.real(x)) + return Array._new(np.real(x), device=x.device) def remainder(x1: Array, x2: Array, /) -> Array: @@ -764,12 +826,14 @@ def remainder(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in remainder") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.remainder(x1._array, x2._array)) + return Array._new(np.remainder(x1._array, x2._array), device=x1.device) def round(x: Array, /) -> Array: @@ -780,7 +844,7 @@ def round(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in round") - return Array._new(np.round(x._array)) + return Array._new(np.round(x._array), device=x.device) def sign(x: Array, /) -> Array: @@ -791,7 +855,7 @@ def sign(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sign") - return Array._new(np.sign(x._array)) + return Array._new(np.sign(x._array), device=x.device) @requires_api_version('2023.12') @@ -803,7 +867,7 @@ def signbit(x: Array, /) -> Array: """ if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in signbit") - return Array._new(np.signbit(x._array)) + return Array._new(np.signbit(x._array), device=x.device) def sin(x: Array, /) -> Array: @@ -814,7 +878,7 @@ def sin(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in sin") - return Array._new(np.sin(x._array)) + return Array._new(np.sin(x._array), device=x.device) def sinh(x: Array, /) -> Array: @@ -825,7 +889,7 @@ def sinh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in sinh") - return Array._new(np.sinh(x._array)) + return Array._new(np.sinh(x._array), device=x.device) def square(x: Array, /) -> Array: @@ -836,7 +900,7 @@ def square(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in square") - return Array._new(np.square(x._array)) + return Array._new(np.square(x._array), device=x.device) def sqrt(x: Array, /) -> Array: @@ -847,7 +911,7 @@ def sqrt(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in sqrt") - return Array._new(np.sqrt(x._array)) + return Array._new(np.sqrt(x._array), device=x.device) def subtract(x1: Array, x2: Array, /) -> Array: @@ -856,12 +920,14 @@ def subtract(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in subtract") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.subtract(x1._array, x2._array)) + return Array._new(np.subtract(x1._array, x2._array), device=x1.device) def tan(x: Array, /) -> Array: @@ -872,7 +938,7 @@ def tan(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in tan") - return Array._new(np.tan(x._array)) + return Array._new(np.tan(x._array), device=x.device) def tanh(x: Array, /) -> Array: @@ -883,7 +949,7 @@ def tanh(x: Array, /) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in tanh") - return Array._new(np.tanh(x._array)) + return Array._new(np.tanh(x._array), device=x.device) def trunc(x: Array, /) -> Array: @@ -897,4 +963,4 @@ def trunc(x: Array, /) -> Array: if x.dtype in _integer_dtypes: # Note: The return dtype of trunc is the same as the input return x - return Array._new(np.trunc(x._array)) + return Array._new(np.trunc(x._array), device=x.device) diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index ab5447a..cfcff8b 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -6,7 +6,7 @@ from typing import Optional, Union, Tuple, List from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info -from ._array_object import CPU_DEVICE +from ._array_object import CPU_DEVICE, Device from ._flags import get_array_api_strict_flags, requires_api_version from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 @@ -121,7 +121,7 @@ def dtypes( @requires_api_version('2023.12') def devices() -> List[device]: - return [CPU_DEVICE] + return [CPU_DEVICE, Device("device1"), Device("device2")] __all__ = [ "capabilities", From 6cc7bac906d0d2daa258d9d854925c5760b66065 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Fri, 27 Sep 2024 11:13:38 +0200 Subject: [PATCH 123/252] Define __hash__ --- array_api_strict/_array_object.py | 6 ++++++ array_api_strict/_creation_functions.py | 5 ++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index be19185..b3e5dbc 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -53,8 +53,14 @@ def __repr__(self): return f"array_api_strict.Device('{self._device}')" def __eq__(self, other): + if not isinstance(other, Device): + return False return self._device == other._device + def __hash__(self): + return hash(("Device", self._device)) + + CPU_DEVICE = Device() _default = object() diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 798e633..3b22fd1 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -307,8 +307,11 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: if len({a.dtype for a in arrays}) > 1: raise ValueError("meshgrid inputs must all have the same dtype") + if len({a.device for a in arrays}) > 1: + raise ValueError("meshgrid inputs must all be on the same device") + return [ - Array._new(array, device=device) + Array._new(array, device=array.device) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing) ] From bca670acd6fe683edacc9cbaf245df012ddd6b5d Mon Sep 17 00:00:00 2001 From: Tim Head Date: Fri, 27 Sep 2024 13:55:22 +0200 Subject: [PATCH 124/252] More device pass through --- array_api_strict/_creation_functions.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 3b22fd1..6291fc5 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -349,6 +349,8 @@ def ones_like( _check_valid_dtype(dtype) _check_device(device) + if device is None: + device = x.device if dtype is not None: dtype = dtype._np_dtype @@ -366,7 +368,7 @@ def tril(x: Array, /, *, k: int = 0) -> Array: if x.ndim < 2: # Note: Unlike np.tril, x must be at least 2-D raise ValueError("x must be at least 2-dimensional for tril") - return Array._new(np.tril(x._array, k=k)) + return Array._new(np.tril(x._array, k=k), device=x.device) def triu(x: Array, /, *, k: int = 0) -> Array: @@ -380,7 +382,7 @@ def triu(x: Array, /, *, k: int = 0) -> Array: if x.ndim < 2: # Note: Unlike np.triu, x must be at least 2-D raise ValueError("x must be at least 2-dimensional for triu") - return Array._new(np.triu(x._array, k=k), device=device) + return Array._new(np.triu(x._array, k=k), device=x.device) def zeros( @@ -416,6 +418,8 @@ def zeros_like( _check_valid_dtype(dtype) _check_device(device) + if device is None: + device = x.device if dtype is not None: dtype = dtype._np_dtype From 426609f87d96c028dddefb9251af1936c1700c4b Mon Sep 17 00:00:00 2001 From: Tim Head Date: Fri, 27 Sep 2024 14:03:03 +0200 Subject: [PATCH 125/252] Fix meshgrid --- array_api_strict/_creation_functions.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 6291fc5..dba54cf 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -310,8 +310,14 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: if len({a.device for a in arrays}) > 1: raise ValueError("meshgrid inputs must all be on the same device") + # arrays is allowed to be empty + if arrays: + device = arrays[0].device + else: + device = None + return [ - Array._new(array, device=array.device) + Array._new(array, device=device) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing) ] From 727072f60044eefca50faa94658a61113e929a1c Mon Sep 17 00:00:00 2001 From: Tim Head Date: Wed, 2 Oct 2024 19:18:25 +0200 Subject: [PATCH 126/252] Add testing and small typo fixes --- array_api_strict/_dtypes.py | 2 +- array_api_strict/_elementwise_functions.py | 4 +- .../tests/test_elementwise_functions.py | 63 +++++++++++++++++-- 3 files changed, 61 insertions(+), 8 deletions(-) diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index a91454f..b51ed92 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ -121,7 +121,7 @@ def __hash__(self): "integer": _integer_dtypes, "integer or boolean": _integer_or_boolean_dtypes, "boolean": _boolean_dtypes, - "real floating-point": _floating_dtypes, + "real floating-point": _real_floating_dtypes, "complex floating-point": _complex_floating_dtypes, "floating-point": _floating_dtypes, } diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 4035841..74109ff 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -354,7 +354,7 @@ def copysign(x1: Array, x2: Array, /) -> Array: if x1.device != x2.device: raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: raise TypeError("Only real numeric dtypes are allowed in copysign") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) @@ -632,7 +632,7 @@ def log10(x: Array, /) -> Array: return Array._new(np.log10(x._array), device=x.device) -def logaddexp(x1: Array, x2: Array) -> Array: +def logaddexp(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.logaddexp `. diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index fa3405a..1be9fbc 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,4 +1,4 @@ -from inspect import getfullargspec, getmodule +from inspect import signature, getmodule from numpy.testing import assert_raises @@ -19,8 +19,16 @@ import pytest +import array_api_strict + + def nargs(func): - return len(getfullargspec(func).args) + """Count number of 'array' arguments a function takes.""" + positional_only = 0 + for param in signature(func).parameters.values(): + if param.kind == param.POSITIONAL_ONLY: + positional_only += 1 + return positional_only elementwise_function_input_types = { @@ -91,12 +99,57 @@ def nargs(func): "trunc": "real numeric", } + +def test_nargs(): + # Explicitly check number of arguments for a few functions + assert nargs(array_api_strict.logaddexp) == 2 + assert nargs(array_api_strict.atan2) == 2 + assert nargs(array_api_strict.clip) == 1 + + # All elementwise functions take one or two array arguments + # if not, it is probably a bug in `nargs` or the definition + # of the function (missing trailing `, /`). + for func_name in elementwise_function_input_types: + func = getattr(_elementwise_functions, func_name) + assert nargs(func) in (1, 2) + + def test_missing_functions(): # Ensure the above dictionary is complete. import array_api_strict._elementwise_functions as mod mod_funcs = [n for n in dir(mod) if getmodule(getattr(mod, n)) is mod] assert set(mod_funcs) == set(elementwise_function_input_types) + +def test_function_device_persists(): + # Test that the device of the input and output array are the same + def _array_vals(dtypes): + for d in dtypes: + yield asarray(1., dtype=d) + + # Use the latest version of the standard so all functions are included + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version="2023.12") + + for func_name, types in elementwise_function_input_types.items(): + dtypes = _dtype_categories[types] + func = getattr(_elementwise_functions, func_name) + print(f"{func_name=} {nargs(func)=} {types=} {dtypes=}") + + for x in _array_vals(dtypes): + if nargs(func) == 2: + # This way we don't have to deal with incompatible + # types of the two arguments. + r = func(x, x) + assert r.device == x.device + + else: + if func_name == "atanh": + x -= 0.1 + r = func(x) + assert r.device == x.device + + def test_function_types(): # Test that every function accepts only the required input types. We only # test the negative cases here (error). The positive cases are tested in @@ -130,12 +183,12 @@ def _array_vals(): or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes ): - assert_raises(TypeError, lambda: func(x, y)) + assert_raises(TypeError, func, x, y) if x.dtype not in dtypes or y.dtype not in dtypes: - assert_raises(TypeError, lambda: func(x, y)) + assert_raises(TypeError, func, x, y) else: if x.dtype not in dtypes: - assert_raises(TypeError, lambda: func(x)) + assert_raises(TypeError, func, x) def test_bitwise_shift_error(): From 23f390e73f909cad8f33b24ea98b757b13e19395 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 3 Oct 2024 09:50:55 +0200 Subject: [PATCH 127/252] Add a comment about atanh special casing --- array_api_strict/tests/test_elementwise_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 1be9fbc..fd8c3ae 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -134,7 +134,6 @@ def _array_vals(dtypes): for func_name, types in elementwise_function_input_types.items(): dtypes = _dtype_categories[types] func = getattr(_elementwise_functions, func_name) - print(f"{func_name=} {nargs(func)=} {types=} {dtypes=}") for x in _array_vals(dtypes): if nargs(func) == 2: @@ -144,6 +143,8 @@ def _array_vals(dtypes): assert r.device == x.device else: + # `atanh` needs a slightly different input value from + # everyone else if func_name == "atanh": x -= 0.1 r = func(x) From 405b7e7b4131a3fa7cbda807e3b66ed9dcf4938e Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 3 Oct 2024 09:58:04 +0200 Subject: [PATCH 128/252] Add conversion to NumPy test The default device should continue to convert, but other arrays from other devices should error. --- array_api_strict/tests/test_array_object.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 5146bba..19382d5 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -349,6 +349,17 @@ def test___array__(): assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64))) assert b.dtype == np.float64 +def test_array_conversion(): + # Check that arrays on the CPU device can be converted to NumPy + # but arrays on other devices can't + a = ones((2, 3)) + np.asarray(a) + + for device in ("device1", "device2"): + a = ones((2, 3), device=array_api_strict.Device(device)) + with pytest.raises(RuntimeError, match="Can not convert array"): + np.asarray(a) + def test_allow_newaxis(): a = ones(5) indexed_a = a[None, :] From 03e1ae7de96bb3e429331f114017a22d9d5cd95d Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 3 Oct 2024 10:16:08 +0200 Subject: [PATCH 129/252] Add multi-device support to sorting functions --- array_api_strict/_sorting_functions.py | 4 ++-- array_api_strict/tests/test_sorting_functions.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_sorting_functions.py b/array_api_strict/_sorting_functions.py index 9b8cb04..765bd9e 100644 --- a/array_api_strict/_sorting_functions.py +++ b/array_api_strict/_sorting_functions.py @@ -33,7 +33,7 @@ def argsort( normalised_axis = axis if axis >= 0 else x.ndim + axis max_i = x.shape[normalised_axis] - 1 res = max_i - res - return Array._new(res) + return Array._new(res, device=x.device) # Note: the descending keyword argument is new in this function def sort( @@ -51,4 +51,4 @@ def sort( res = np.sort(x._array, axis=axis, kind=kind) if descending: res = np.flip(res, axis=axis) - return Array._new(res) + return Array._new(res, device=x.device) diff --git a/array_api_strict/tests/test_sorting_functions.py b/array_api_strict/tests/test_sorting_functions.py index c479260..716a651 100644 --- a/array_api_strict/tests/test_sorting_functions.py +++ b/array_api_strict/tests/test_sorting_functions.py @@ -21,3 +21,17 @@ def test_stable_desc_argsort(obj, axis, expected): x = xp.asarray(obj) out = xp.argsort(x, axis=axis, stable=True, descending=True) assert xp.all(out == xp.asarray(expected)) + + +def test_argsort_device(): + x = xp.asarray([1., 2., -1., 3.141], device=xp.Device("device1")) + y = xp.argsort(x) + + assert y.device == x.device + + +def test_sort_device(): + x = xp.asarray([1., 2., -1., 3.141], device=xp.Device("device1")) + y = xp.sort(x) + + assert y.device == x.device \ No newline at end of file From 3bc8199874a61b6c77fc77d4959827935b94ee43 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 3 Oct 2024 15:05:27 +0200 Subject: [PATCH 130/252] More multi-device support --- array_api_strict/_array_object.py | 1 + array_api_strict/_creation_functions.py | 5 ++- array_api_strict/_data_type_functions.py | 4 ++- array_api_strict/_fft.py | 34 +++++++++---------- array_api_strict/tests/test_device_support.py | 25 ++++++++++++++ 5 files changed, 50 insertions(+), 19 deletions(-) create mode 100644 array_api_strict/tests/test_device_support.py diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index b3e5dbc..e453962 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -62,6 +62,7 @@ def __hash__(self): CPU_DEVICE = Device() +ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2")) _default = object() diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index dba54cf..a215f59 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -32,11 +32,14 @@ def _supports_buffer_protocol(obj): def _check_device(device): # _array_object imports in this file are inside the functions to avoid # circular imports - from ._array_object import Device + from ._array_object import Device, ALL_DEVICES if device is not None and not isinstance(device, Device): raise ValueError(f"Unsupported device {device!r}") + if device not in ALL_DEVICES: + raise ValueError(f"Unsupported device {device!r}") + def asarray( obj: Union[ Array, diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 3405710..56d1a02 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -37,10 +37,12 @@ def astype( _check_device(device) else: raise TypeError("The device argument to astype requires at least version 2023.12 of the array API") + else: + device = x.device if not copy and dtype == x.dtype: return x - return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy)) + return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy), device=device) def broadcast_arrays(*arrays: Array) -> List[Array]: diff --git a/array_api_strict/_fft.py b/array_api_strict/_fft.py index 32b9551..5b67b0c 100644 --- a/array_api_strict/_fft.py +++ b/array_api_strict/_fft.py @@ -14,7 +14,7 @@ float32, complex64, ) -from ._array_object import Array, CPU_DEVICE +from ._array_object import Array, ALL_DEVICES from ._data_type_functions import astype from ._flags import requires_extension @@ -36,7 +36,7 @@ def fft( """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in fft") - res = Array._new(np.fft.fft(x._array, n=n, axis=axis, norm=norm)) + res = Array._new(np.fft.fft(x._array, n=n, axis=axis, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == complex64: @@ -59,7 +59,7 @@ def ifft( """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in ifft") - res = Array._new(np.fft.ifft(x._array, n=n, axis=axis, norm=norm)) + res = Array._new(np.fft.ifft(x._array, n=n, axis=axis, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == complex64: @@ -82,7 +82,7 @@ def fftn( """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in fftn") - res = Array._new(np.fft.fftn(x._array, s=s, axes=axes, norm=norm)) + res = Array._new(np.fft.fftn(x._array, s=s, axes=axes, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == complex64: @@ -105,7 +105,7 @@ def ifftn( """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in ifftn") - res = Array._new(np.fft.ifftn(x._array, s=s, axes=axes, norm=norm)) + res = Array._new(np.fft.ifftn(x._array, s=s, axes=axes, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == complex64: @@ -128,7 +128,7 @@ def rfft( """ if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in rfft") - res = Array._new(np.fft.rfft(x._array, n=n, axis=axis, norm=norm)) + res = Array._new(np.fft.rfft(x._array, n=n, axis=axis, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == float32: @@ -151,7 +151,7 @@ def irfft( """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in irfft") - res = Array._new(np.fft.irfft(x._array, n=n, axis=axis, norm=norm)) + res = Array._new(np.fft.irfft(x._array, n=n, axis=axis, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == complex64: @@ -174,7 +174,7 @@ def rfftn( """ if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in rfftn") - res = Array._new(np.fft.rfftn(x._array, s=s, axes=axes, norm=norm)) + res = Array._new(np.fft.rfftn(x._array, s=s, axes=axes, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == float32: @@ -197,7 +197,7 @@ def irfftn( """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in irfftn") - res = Array._new(np.fft.irfftn(x._array, s=s, axes=axes, norm=norm)) + res = Array._new(np.fft.irfftn(x._array, s=s, axes=axes, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == complex64: @@ -220,7 +220,7 @@ def hfft( """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in hfft") - res = Array._new(np.fft.hfft(x._array, n=n, axis=axis, norm=norm)) + res = Array._new(np.fft.hfft(x._array, n=n, axis=axis, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == complex64: @@ -243,7 +243,7 @@ def ihfft( """ if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in ihfft") - res = Array._new(np.fft.ihfft(x._array, n=n, axis=axis, norm=norm)) + res = Array._new(np.fft.ihfft(x._array, n=n, axis=axis, norm=norm), device=x.device) # Note: np.fft functions improperly upcast float32 and complex64 to # complex128 if x.dtype == float32: @@ -257,9 +257,9 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Ar See its docstring for more information. """ - if device not in [CPU_DEVICE, None]: + if device not in ALL_DEVICES: raise ValueError(f"Unsupported device {device!r}") - return Array._new(np.fft.fftfreq(n, d=d)) + return Array._new(np.fft.fftfreq(n, d=d), device=device) @requires_extension('fft') def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array: @@ -268,9 +268,9 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> A See its docstring for more information. """ - if device not in [CPU_DEVICE, None]: + if device not in ALL_DEVICES: raise ValueError(f"Unsupported device {device!r}") - return Array._new(np.fft.rfftfreq(n, d=d)) + return Array._new(np.fft.rfftfreq(n, d=d), device=device) @requires_extension('fft') def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: @@ -281,7 +281,7 @@ def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in fftshift") - return Array._new(np.fft.fftshift(x._array, axes=axes)) + return Array._new(np.fft.fftshift(x._array, axes=axes), device=x.device) @requires_extension('fft') def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: @@ -292,7 +292,7 @@ def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: """ if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in ifftshift") - return Array._new(np.fft.ifftshift(x._array, axes=axes)) + return Array._new(np.fft.ifftshift(x._array, axes=axes), device=x.device) __all__ = [ "fft", diff --git a/array_api_strict/tests/test_device_support.py b/array_api_strict/tests/test_device_support.py new file mode 100644 index 0000000..c7fd340 --- /dev/null +++ b/array_api_strict/tests/test_device_support.py @@ -0,0 +1,25 @@ +import pytest + +import array_api_strict + + +@pytest.mark.parametrize("func_name", ("fft", "ifft", "fftn", "ifftn", "irfft", + "irfftn", "hfft", "fftshift", "ifftshift")) +def test_fft_device_support_complex(func_name): + func = getattr(array_api_strict.fft, func_name) + x = array_api_strict.asarray([1, 2.], + dtype=array_api_strict.complex64, + device=array_api_strict.Device("device1")) + y = func(x) + + assert x.device == y.device + + +@pytest.mark.parametrize("func_name", ("rfft", "rfftn", "ihfft")) +def test_fft_device_support_real(func_name): + func = getattr(array_api_strict.fft, func_name) + x = array_api_strict.asarray([1, 2.], + device=array_api_strict.Device("device1")) + y = func(x) + + assert x.device == y.device \ No newline at end of file From 032f3bb8e4623a9a080d89ffbcd6b11fafddb336 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 3 Oct 2024 15:15:27 +0200 Subject: [PATCH 131/252] Formatting --- array_api_strict/tests/test_device_support.py | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/array_api_strict/tests/test_device_support.py b/array_api_strict/tests/test_device_support.py index c7fd340..0f3d6b5 100644 --- a/array_api_strict/tests/test_device_support.py +++ b/array_api_strict/tests/test_device_support.py @@ -3,13 +3,27 @@ import array_api_strict -@pytest.mark.parametrize("func_name", ("fft", "ifft", "fftn", "ifftn", "irfft", - "irfftn", "hfft", "fftshift", "ifftshift")) +@pytest.mark.parametrize( + "func_name", + ( + "fft", + "ifft", + "fftn", + "ifftn", + "irfft", + "irfftn", + "hfft", + "fftshift", + "ifftshift", + ), +) def test_fft_device_support_complex(func_name): func = getattr(array_api_strict.fft, func_name) - x = array_api_strict.asarray([1, 2.], - dtype=array_api_strict.complex64, - device=array_api_strict.Device("device1")) + x = array_api_strict.asarray( + [1, 2.0], + dtype=array_api_strict.complex64, + device=array_api_strict.Device("device1"), + ) y = func(x) assert x.device == y.device @@ -18,8 +32,7 @@ def test_fft_device_support_complex(func_name): @pytest.mark.parametrize("func_name", ("rfft", "rfftn", "ihfft")) def test_fft_device_support_real(func_name): func = getattr(array_api_strict.fft, func_name) - x = array_api_strict.asarray([1, 2.], - device=array_api_strict.Device("device1")) + x = array_api_strict.asarray([1, 2.0], device=array_api_strict.Device("device1")) y = func(x) - assert x.device == y.device \ No newline at end of file + assert x.device == y.device From 724e0714a2cdd9aea89e2069bf4229e4df2cea51 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 3 Oct 2024 15:23:19 +0200 Subject: [PATCH 132/252] Add multi-device test for take --- array_api_strict/_creation_functions.py | 2 +- array_api_strict/_indexing_functions.py | 4 +++- .../tests/test_indexing_functions.py | 20 +++++++++++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index a215f59..18cbdfd 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -37,7 +37,7 @@ def _check_device(device): if device is not None and not isinstance(device, Device): raise ValueError(f"Unsupported device {device!r}") - if device not in ALL_DEVICES: + if device is not None and device not in ALL_DEVICES: raise ValueError(f"Unsupported device {device!r}") def asarray( diff --git a/array_api_strict/_indexing_functions.py b/array_api_strict/_indexing_functions.py index 316a3a7..7506d4f 100644 --- a/array_api_strict/_indexing_functions.py +++ b/array_api_strict/_indexing_functions.py @@ -22,4 +22,6 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: raise TypeError("Only integer dtypes are allowed in indexing") if indices.ndim != 1: raise ValueError("Only 1-dim indices array is supported") - return Array._new(np.take(x._array, indices._array, axis=axis)) + if x.device != indices.device: + raise RuntimeError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.") + return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device) diff --git a/array_api_strict/tests/test_indexing_functions.py b/array_api_strict/tests/test_indexing_functions.py index fabe688..6d65239 100644 --- a/array_api_strict/tests/test_indexing_functions.py +++ b/array_api_strict/tests/test_indexing_functions.py @@ -22,3 +22,23 @@ def test_take_function(x, indices, axis, expected): indices = xp.asarray(indices) out = xp.take(x, indices, axis=axis) assert xp.all(out == xp.asarray(expected)) + + +def test_take_device(): + x = xp.asarray([2, 3]) + indices = xp.asarray([1, 1, 0]) + xp.take(x, indices) + + x = xp.asarray([2, 3]) + indices = xp.asarray([1, 1, 0], device=xp.Device("device1")) + with pytest.raises(RuntimeError, match="Arrays from two different devices"): + xp.take(x, indices) + + x = xp.asarray([2, 3], device=xp.Device("device1")) + indices = xp.asarray([1, 1, 0]) + with pytest.raises(RuntimeError, match="Arrays from two different devices"): + xp.take(x, indices) + + x = xp.asarray([2, 3], device=xp.Device("device1")) + indices = xp.asarray([1, 1, 0], device=xp.Device("device1")) + xp.take(x, indices) \ No newline at end of file From e0b2a64f8b90489898d5ec14aef4edb6ca5f76f5 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 3 Oct 2024 16:27:56 +0200 Subject: [PATCH 133/252] Multi-device support in linear algebra functions --- array_api_strict/_info.py | 4 ++-- array_api_strict/_linear_algebra_functions.py | 17 +++++++++++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index cfcff8b..3ed7fb2 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -6,7 +6,7 @@ from typing import Optional, Union, Tuple, List from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info -from ._array_object import CPU_DEVICE, Device +from ._array_object import ALL_DEVICES, CPU_DEVICE from ._flags import get_array_api_strict_flags, requires_api_version from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 @@ -121,7 +121,7 @@ def dtypes( @requires_api_version('2023.12') def devices() -> List[device]: - return [CPU_DEVICE, Device("device1"), Device("device2")] + return list(ALL_DEVICES) __all__ = [ "capabilities", diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py index dcb654d..6746bc7 100644 --- a/array_api_strict/_linear_algebra_functions.py +++ b/array_api_strict/_linear_algebra_functions.py @@ -30,7 +30,10 @@ def matmul(x1: Array, x2: Array, /) -> Array: if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in matmul') - return Array._new(np.matmul(x1._array, x2._array)) + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + + return Array._new(np.matmul(x1._array, x2._array), device=x1.device) # Note: tensordot is the numpy top-level namespace but not in np.linalg @@ -41,14 +44,17 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in tensordot') - return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + + return Array._new(np.tensordot(x1._array, x2._array, axes=axes), device=x1.device) # Note: this function is new in the array API spec. Unlike transpose, it only # transposes the last two axes. def matrix_transpose(x: Array, /) -> Array: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") - return Array._new(np.swapaxes(x._array, -1, -2)) + return Array._new(np.swapaxes(x._array, -1, -2), device=x.device) # Note: vecdot is not in NumPy def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: @@ -61,6 +67,9 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: elif axis < min(-1, -x1.ndim, -x2.ndim): raise ValueError("axis is out of bounds for x1 and x2") + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + # In versions of the standard prior to 2023.12, vecdot applied axis after # broadcasting. This is different from applying it before broadcasting # when axis is nonnegative. The below code keeps this behavior for @@ -78,4 +87,4 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x2_ = np.moveaxis(x2_, axis, -1) res = x1_[..., None, :] @ x2_[..., None] - return Array._new(res[..., 0, 0]) + return Array._new(res[..., 0, 0], device=x1.device) From 932332403455de7c228b8084474ea0f6e0b7033c Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 3 Oct 2024 16:40:16 +0200 Subject: [PATCH 134/252] Multi-device support for array manipulation --- array_api_strict/_manipulation_functions.py | 27 ++++++++++++--------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index 7652028..2162ffd 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -25,8 +25,11 @@ def concat( # Note: Casting rules here are different from the np.concatenate default # (no for scalars with axis=None, no cross-kind casting) dtype = result_type(*arrays) + if len({a.device for a in arrays}) > 1: + raise ValueError("concat inputs must all be on the same device") + arrays = tuple(a._array for a in arrays) - return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype._np_dtype)) + return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype._np_dtype), device=arrays[0].device) def expand_dims(x: Array, /, *, axis: int) -> Array: @@ -35,7 +38,7 @@ def expand_dims(x: Array, /, *, axis: int) -> Array: See its docstring for more information. """ - return Array._new(np.expand_dims(x._array, axis)) + return Array._new(np.expand_dims(x._array, axis), device=x.device) def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: @@ -44,7 +47,7 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> See its docstring for more information. """ - return Array._new(np.flip(x._array, axis=axis)) + return Array._new(np.flip(x._array, axis=axis), device=x.device) @requires_api_version('2023.12') def moveaxis( @@ -58,7 +61,7 @@ def moveaxis( See its docstring for more information. """ - return Array._new(np.moveaxis(x._array, source, destination)) + return Array._new(np.moveaxis(x._array, source, destination), device=x.device) # Note: The function name is different here (see also matrix_transpose). # Unlike transpose(), the axes argument is required. @@ -68,7 +71,7 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: See its docstring for more information. """ - return Array._new(np.transpose(x._array, axes)) + return Array._new(np.transpose(x._array, axes), device=x.device) @requires_api_version('2023.12') def repeat( @@ -94,7 +97,7 @@ def repeat( else: raise TypeError("repeats must be an int or array") - return Array._new(np.repeat(x._array, repeats, axis=axis)) + return Array._new(np.repeat(x._array, repeats, axis=axis), device=x.device) # Note: the optional argument is called 'shape', not 'newshape' def reshape(x: Array, @@ -117,7 +120,7 @@ def reshape(x: Array, if copy is False and not np.shares_memory(data, reshaped): raise AttributeError("Incompatible shape for in-place modification.") - return Array._new(reshaped) + return Array._new(reshaped, device=x.device) def roll( @@ -132,7 +135,7 @@ def roll( See its docstring for more information. """ - return Array._new(np.roll(x._array, shift, axis=axis)) + return Array._new(np.roll(x._array, shift, axis=axis), device=x.device) def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: @@ -141,7 +144,7 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: See its docstring for more information. """ - return Array._new(np.squeeze(x._array, axis=axis)) + return Array._new(np.squeeze(x._array, axis=axis), device=x.device) def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array: @@ -152,8 +155,10 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> """ # Call result type here just to raise on disallowed type combinations result_type(*arrays) + if len({a.device for a in arrays}) > 1: + raise ValueError("concat inputs must all be on the same device") arrays = tuple(a._array for a in arrays) - return Array._new(np.stack(arrays, axis=axis)) + return Array._new(np.stack(arrays, axis=axis), device=arrays[0].device) @requires_api_version('2023.12') @@ -166,7 +171,7 @@ def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array: # Note: NumPy allows repetitions to be an int or array if not isinstance(repetitions, tuple): raise TypeError("repetitions must be a tuple") - return Array._new(np.tile(x._array, repetitions)) + return Array._new(np.tile(x._array, repetitions), device=x.device) # Note: this function is new @requires_api_version('2023.12') From ff37de742ebd4d12f935aaa9738a77fa451b4a26 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 3 Oct 2024 16:49:26 +0200 Subject: [PATCH 135/252] Add multi-device support for searching --- array_api_strict/_searching_functions.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 7314895..2b52889 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -19,7 +19,7 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - """ if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in argmax") - return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims)), device=x.device) def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: @@ -30,7 +30,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - """ if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in argmin") - return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims)), device=x.device) @requires_data_dependent_shapes @@ -61,12 +61,16 @@ def searchsorted( """ if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in searchsorted") + + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + sorter = sorter._array if sorter is not None else None # TODO: The sort order of nans and signed zeros is implementation # dependent. Should we error/warn if they are present? # x1 must be 1-D, but NumPy already requires this. - return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter)) + return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device) def where(condition: Array, x1: Array, x2: Array, /) -> Array: """ @@ -76,5 +80,9 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array: """ # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) + + if len({a.device for a in (condition, x1, x2)}) > 1: + raise ValueError("where inputs must all be on the same device") + x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.where(condition._array, x1._array, x2._array)) + return Array._new(np.where(condition._array, x1._array, x2._array), device=x1.device) From bae74820202dad57059b02440e78a203e579f009 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 3 Oct 2024 17:05:40 +0200 Subject: [PATCH 136/252] Add multi-device support to stats and sets --- array_api_strict/_set_functions.py | 15 ++++++++------- array_api_strict/_statistical_functions.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/array_api_strict/_set_functions.py b/array_api_strict/_set_functions.py index e6ca939..7bd5bad 100644 --- a/array_api_strict/_set_functions.py +++ b/array_api_strict/_set_functions.py @@ -55,10 +55,10 @@ def unique_all(x: Array, /) -> UniqueAllResult: # See https://github.com/numpy/numpy/issues/20638 inverse_indices = inverse_indices.reshape(x.shape) return UniqueAllResult( - Array._new(values), - Array._new(indices), - Array._new(inverse_indices), - Array._new(counts), + Array._new(values, device=x.device), + Array._new(indices, device=x.device), + Array._new(inverse_indices, device=x.device), + Array._new(counts, device=x.device), ) @@ -72,7 +72,7 @@ def unique_counts(x: Array, /) -> UniqueCountsResult: equal_nan=False, ) - return UniqueCountsResult(*[Array._new(i) for i in res]) + return UniqueCountsResult(*[Array._new(i, device=x.device) for i in res]) @requires_data_dependent_shapes @@ -92,7 +92,8 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult: # np.unique() flattens inverse indices, but they need to share x's shape # See https://github.com/numpy/numpy/issues/20638 inverse_indices = inverse_indices.reshape(x.shape) - return UniqueInverseResult(Array._new(values), Array._new(inverse_indices)) + return UniqueInverseResult(Array._new(values, device=x.device), + Array._new(inverse_indices, device=x.device)) @requires_data_dependent_shapes @@ -109,4 +110,4 @@ def unique_values(x: Array, /) -> Array: return_inverse=False, equal_nan=False, ) - return Array._new(res) + return Array._new(res, device=x.device) diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 39e3736..6ea9746 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -44,7 +44,7 @@ def cumulative_sum( if axis < 0: axis += x.ndim x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis) - return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype)) + return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device) def max( x: Array, @@ -55,7 +55,7 @@ def max( ) -> Array: if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in max") - return Array._new(np.max(x._array, axis=axis, keepdims=keepdims)) + return Array._new(np.max(x._array, axis=axis, keepdims=keepdims), device=x.device) def mean( @@ -67,7 +67,7 @@ def mean( ) -> Array: if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in mean") - return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims)) + return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims), device=x.device) def min( @@ -79,7 +79,7 @@ def min( ) -> Array: if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in min") - return Array._new(np.min(x._array, axis=axis, keepdims=keepdims)) + return Array._new(np.min(x._array, axis=axis, keepdims=keepdims), device=x.device) def prod( @@ -104,7 +104,7 @@ def prod( dtype = np.complex128 else: dtype = dtype._np_dtype - return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims)) + return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims), device=x.device) def std( @@ -118,7 +118,7 @@ def std( # Note: the keyword argument correction is different here if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in std") - return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims)) + return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims), device=x.device) def sum( @@ -143,7 +143,7 @@ def sum( dtype = np.complex128 else: dtype = dtype._np_dtype - return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims)) + return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims), device=x.device) def var( @@ -157,4 +157,4 @@ def var( # Note: the keyword argument correction is different here if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in var") - return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims)) + return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims), device=x.device) From cca1785f2710e1c0c38d6a1a6e2960ccd2fe4ff3 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 3 Oct 2024 17:11:11 +0200 Subject: [PATCH 137/252] Add multi-device support for utils --- array_api_strict/_utility_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_utility_functions.py b/array_api_strict/_utility_functions.py index c91fa58..0d44ecb 100644 --- a/array_api_strict/_utility_functions.py +++ b/array_api_strict/_utility_functions.py @@ -21,7 +21,7 @@ def all( See its docstring for more information. """ - return Array._new(np.asarray(np.all(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.all(x._array, axis=axis, keepdims=keepdims)), device=x.device) def any( @@ -36,4 +36,4 @@ def any( See its docstring for more information. """ - return Array._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims))) + return Array._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims)), device=x.device) From a96c497232ba99339b23023a599d566194ee1e90 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 7 Oct 2024 14:10:27 +0200 Subject: [PATCH 138/252] More FFT multi-device --- array_api_strict/_array_object.py | 29 +++++++++++++++++++++-------- array_api_strict/_fft.py | 4 ++-- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index e453962..35dbf4e 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -859,12 +859,13 @@ def __rshift__(self: Array, other: Union[int, Array], /) -> Array: """ Performs the operation __rshift__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "integer", "__rshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__rshift__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __setitem__( self, @@ -889,12 +890,13 @@ def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __sub__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__sub__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__sub__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) # PEP 484 requires int to be a subtype of float, but __truediv__ should # not accept int. @@ -902,28 +904,31 @@ def __truediv__(self: Array, other: Union[float, Array], /) -> Array: """ Performs the operation __truediv__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "floating-point", "__truediv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__truediv__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __xor__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __xor__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__xor__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __iadd__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __iadd__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__iadd__") if other is NotImplemented: return other @@ -934,17 +939,19 @@ def __radd__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __radd__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__radd__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__radd__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __iand__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __iand__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__") if other is NotImplemented: return other @@ -955,17 +962,19 @@ def __rand__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __rand__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__rand__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ifloordiv__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__") if other is NotImplemented: return other @@ -976,17 +985,19 @@ def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __rfloordiv__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__rfloordiv__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __ilshift__(self: Array, other: Union[int, Array], /) -> Array: """ Performs the operation __ilshift__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "integer", "__ilshift__") if other is NotImplemented: return other @@ -997,17 +1008,19 @@ def __rlshift__(self: Array, other: Union[int, Array], /) -> Array: """ Performs the operation __rlshift__. """ + other = self._check_device(other) other = self._check_allowed_dtypes(other, "integer", "__rlshift__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__rlshift__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __imatmul__(self: Array, other: Array, /) -> Array: """ Performs the operation __imatmul__. """ + other = self._check_device(other) # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._check_allowed_dtypes(other, "numeric", "__imatmul__") diff --git a/array_api_strict/_fft.py b/array_api_strict/_fft.py index 5b67b0c..4b0ceb6 100644 --- a/array_api_strict/_fft.py +++ b/array_api_strict/_fft.py @@ -257,7 +257,7 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Ar See its docstring for more information. """ - if device not in ALL_DEVICES: + if device is not None and device not in ALL_DEVICES: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.fft.fftfreq(n, d=d), device=device) @@ -268,7 +268,7 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> A See its docstring for more information. """ - if device not in ALL_DEVICES: + if device is not None and device not in ALL_DEVICES: raise ValueError(f"Unsupported device {device!r}") return Array._new(np.fft.rfftfreq(n, d=d), device=device) From 58334e58ec89e38a005cd007001d7df206e41045 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 7 Oct 2024 14:16:45 +0200 Subject: [PATCH 139/252] Fix weird ruff error --- array_api_strict/tests/test_elementwise_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index fd8c3ae..104b793 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,5 +1,7 @@ from inspect import signature, getmodule +import pytest + from numpy.testing import assert_raises from .. import asarray, _elementwise_functions @@ -17,8 +19,6 @@ ) from .._flags import set_array_api_strict_flags -import pytest - import array_api_strict From 9c5436c8e7bdf29963abf6cafb87796b384afa19 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 7 Oct 2024 14:33:22 +0200 Subject: [PATCH 140/252] New default version --- array_api_strict/tests/test_elementwise_functions.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index e2b63d4..de11edf 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,7 +1,5 @@ from inspect import signature, getmodule -import pytest - from numpy.testing import assert_raises from .. import asarray, _elementwise_functions @@ -128,8 +126,7 @@ def _array_vals(dtypes): yield asarray(1., dtype=d) # Use the latest version of the standard so all functions are included - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version="2023.12") + set_array_api_strict_flags(api_version="2023.12") for func_name, types in elementwise_function_input_types.items(): dtypes = _dtype_categories[types] From 0dbabcc15184fc76bdca3186f3cfbbb51e3c0f95 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 7 Oct 2024 16:04:17 +0200 Subject: [PATCH 141/252] Fix result device --- array_api_strict/_manipulation_functions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index 2162ffd..0dd4343 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -27,9 +27,10 @@ def concat( dtype = result_type(*arrays) if len({a.device for a in arrays}) > 1: raise ValueError("concat inputs must all be on the same device") + result_device = arrays[0].device arrays = tuple(a._array for a in arrays) - return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype._np_dtype), device=arrays[0].device) + return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype._np_dtype), device=result_device) def expand_dims(x: Array, /, *, axis: int) -> Array: @@ -157,8 +158,9 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> result_type(*arrays) if len({a.device for a in arrays}) > 1: raise ValueError("concat inputs must all be on the same device") + result_device = arrays[0].device arrays = tuple(a._array for a in arrays) - return Array._new(np.stack(arrays, axis=axis), device=arrays[0].device) + return Array._new(np.stack(arrays, axis=axis), device=result_device) @requires_api_version('2023.12') From 6d780a8ab542af6e09477416a2843f921878c5a4 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 15 Oct 2024 11:28:13 -0600 Subject: [PATCH 142/252] Fix issue with repeat() NumPy does not allow repeats to be uint64 because it refuses to downcast it. Technically it worked before because we implement __array__ and repeat does manually cast in that case. I'm not really sure we should be supporting __array__ actually. --- array_api_strict/_manipulation_functions.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index 7652028..702d259 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -2,8 +2,8 @@ from ._array_object import Array from ._creation_functions import asarray -from ._data_type_functions import result_type -from ._dtypes import _integer_dtypes +from ._data_type_functions import astype, result_type +from ._dtypes import _integer_dtypes, int64, uint64 from ._flags import requires_api_version, get_array_api_strict_flags from typing import TYPE_CHECKING @@ -94,7 +94,13 @@ def repeat( else: raise TypeError("repeats must be an int or array") - return Array._new(np.repeat(x._array, repeats, axis=axis)) + if repeats.dtype == uint64: + # NumPy does not allow uint64 because can't be cast down to x.dtype + # with 'safe' casting. However, repeats values larger than 2**63 are + # infeasable, and even if they are present by mistake, this will + # lead to underflow and an error. + repeats = astype(repeats, int64) + return Array._new(np.repeat(x._array, repeats._array, axis=axis)) # Note: the optional argument is called 'shape', not 'newshape' def reshape(x: Array, From ff126d7baf60604bb86dad50562979a829c9b94b Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 15 Oct 2024 14:25:06 -0600 Subject: [PATCH 143/252] Fix some functions that were missing ._array They were only working because we define __array__, which may be going away (#67). --- array_api_strict/_elementwise_functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index b39bd86..ab1cbb7 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -317,7 +317,7 @@ def conj(x: Array, /) -> Array: """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in conj") - return Array._new(np.conj(x)) + return Array._new(np.conj(x._array)) @requires_api_version('2023.12') def copysign(x1: Array, x2: Array, /) -> Array: @@ -480,7 +480,7 @@ def imag(x: Array, /) -> Array: """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in imag") - return Array._new(np.imag(x)) + return Array._new(np.imag(x._array)) def isfinite(x: Array, /) -> Array: @@ -755,7 +755,7 @@ def real(x: Array, /) -> Array: """ if x.dtype not in _complex_floating_dtypes: raise TypeError("Only complex floating-point dtypes are allowed in real") - return Array._new(np.real(x)) + return Array._new(np.real(x._array)) def remainder(x1: Array, x2: Array, /) -> Array: From 38551c6d323fa610a6b18612ba5793e4d4ac2a87 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 15 Oct 2024 14:34:58 -0600 Subject: [PATCH 144/252] Remove __array__ This makes it raise an exception, since it isn't supported by the standard (if we just leave it unimplemented, then np.asarray() returns an object dtype array, which is not good). There is one issue here from the test suite, which is that this breaks the logic in asarray() for converting lists of array_api_strict 0-D arrays. I'm not yet sure what to do about that. Fixes #67. --- array_api_strict/_array_object.py | 31 +++++++-------------- array_api_strict/tests/test_array_object.py | 7 ----- 2 files changed, 10 insertions(+), 28 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index d8ed018..e7b53d9 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -125,28 +125,17 @@ def __repr__(self: Array, /) -> str: mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix) return prefix + mid + suffix - # This function is not required by the spec, but we implement it here for - # convenience so that np.asarray(array_api_strict.Array) will work. + # Disallow __array__, meaning calling `np.func()` on an array_api_strict + # array will give an error. If we don't explicitly disallow it, NumPy + # defaults to creating an object dtype array, which would lead to + # confusing error messages at best and surprising bugs at worst. + # + # The alternative of course is to just support __array__, which is what we + # used to do. But this isn't actually supported by the standard, so it can + # lead to code assuming np.asarray(other_array) would always work in the + # standard. def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]: - """ - Warning: this method is NOT part of the array API spec. Implementers - of other libraries need not include it, and users should not assume it - will be present in other implementations. - - """ - # copy keyword is new in 2.0.0; for older versions don't use it - # retry without that keyword. - if np.__version__[0] < '2': - return np.asarray(self._array, dtype=dtype) - elif np.__version__.startswith('2.0.0-dev0'): - # Handle dev version for which we can't know based on version - # number whether or not the copy keyword is supported. - try: - return np.asarray(self._array, dtype=dtype, copy=copy) - except TypeError: - return np.asarray(self._array, dtype=dtype) - else: - return np.asarray(self._array, dtype=dtype, copy=copy) + raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported") # These are various helper functions to make the array behavior match the # spec in places where it either deviates from or is more strict than diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index dad6696..dd3f6c2 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -342,13 +342,6 @@ def test_array_properties(): assert isinstance(b.mT, Array) assert b.mT.shape == (3, 2) -def test___array__(): - a = ones((2, 3), dtype=int16) - assert np.asarray(a) is a._array - b = np.asarray(a, dtype=np.float64) - assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64))) - assert b.dtype == np.float64 - def test_allow_newaxis(): a = ones(5) indexed_a = a[None, :] From 8e6365bca3df0184d0c8a6a438bcddb1dd70f365 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Wed, 16 Oct 2024 15:27:08 +0200 Subject: [PATCH 145/252] Make device= a required argument to create an Array --- array_api_strict/_array_object.py | 99 +++++++++++++----------- array_api_strict/_creation_functions.py | 4 +- array_api_strict/_data_type_functions.py | 4 +- array_api_strict/_linalg.py | 51 +++++++----- array_api_strict/_searching_functions.py | 2 +- 5 files changed, 89 insertions(+), 71 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 35dbf4e..16d5d1a 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -86,7 +86,7 @@ class Array: # Use a custom constructor instead of __init__, as manually initializing # this class is not supported API. @classmethod - def _new(cls, x, /, device=None): + def _new(cls, x, /, device): """ This is a private method for initializing the array API Array object. @@ -218,11 +218,10 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor def _check_device(self, other): """Check that other is on a device compatible with the current array""" if isinstance(other, (int, complex, float, bool)): - return other + return elif isinstance(other, Array): if self.device != other.device: raise RuntimeError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") - return other # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar): @@ -275,7 +274,7 @@ def _promote_scalar(self, scalar): # behavior for integers within the bounds of the integer dtype. # Outside of those bounds we use the default NumPy behavior (either # cast or raise OverflowError). - return Array._new(np.array(scalar, dtype=self.dtype._np_dtype)) + return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=CPU_DEVICE) @staticmethod def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: @@ -307,9 +306,9 @@ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: # performant. broadcast_to(x1._array, x2.shape) is much slower. We # could also manually type promote x2, but that is more complicated # and about the same performance as this. - x1 = Array._new(x1._array[None]) + x1 = Array._new(x1._array[None], device=x1.device) elif x2.ndim == 0 and x1.ndim != 0: - x2 = Array._new(x2._array[None]) + x2 = Array._new(x2._array[None], device=x2.device) return (x1, x2) # Note: A large fraction of allowed indices are disallowed here (see the @@ -493,13 +492,13 @@ def __abs__(self: Array, /) -> Array: if self.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in __abs__") res = self._array.__abs__() - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __add__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __add__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__add__") if other is NotImplemented: return other @@ -511,7 +510,7 @@ def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __and__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__and__") if other is NotImplemented: return other @@ -601,7 +600,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ Performs the operation __eq__. """ - other = self._check_device(other) + self._check_device(other) # Even though "all" dtypes are allowed, we still require them to be # promotable with each other. other = self._check_allowed_dtypes(other, "all", "__eq__") @@ -627,7 +626,7 @@ def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __floordiv__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__") if other is NotImplemented: return other @@ -639,7 +638,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ge__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__ge__") if other is NotImplemented: return other @@ -675,13 +674,13 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __gt__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__gt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__gt__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=other.device) def __int__(self: Array, /) -> int: """ @@ -730,7 +729,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __le__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__le__") if other is NotImplemented: return other @@ -742,7 +741,7 @@ def __lshift__(self: Array, other: Union[int, Array], /) -> Array: """ Performs the operation __lshift__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "integer", "__lshift__") if other is NotImplemented: return other @@ -754,7 +753,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __lt__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__lt__") if other is NotImplemented: return other @@ -766,7 +765,7 @@ def __matmul__(self: Array, other: Array, /) -> Array: """ Performs the operation __matmul__. """ - other = self._check_device(other) + self._check_device(other) # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._check_allowed_dtypes(other, "numeric", "__matmul__") @@ -779,7 +778,7 @@ def __mod__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __mod__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__mod__") if other is NotImplemented: return other @@ -791,7 +790,7 @@ def __mul__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __mul__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__mul__") if other is NotImplemented: return other @@ -803,7 +802,7 @@ def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array: """ Performs the operation __ne__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "all", "__ne__") if other is NotImplemented: return other @@ -824,7 +823,7 @@ def __or__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __or__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__or__") if other is NotImplemented: return other @@ -847,7 +846,7 @@ def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: """ from ._elementwise_functions import pow - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__pow__") if other is NotImplemented: return other @@ -859,7 +858,7 @@ def __rshift__(self: Array, other: Union[int, Array], /) -> Array: """ Performs the operation __rshift__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "integer", "__rshift__") if other is NotImplemented: return other @@ -890,7 +889,7 @@ def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __sub__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__sub__") if other is NotImplemented: return other @@ -904,7 +903,7 @@ def __truediv__(self: Array, other: Union[float, Array], /) -> Array: """ Performs the operation __truediv__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "floating-point", "__truediv__") if other is NotImplemented: return other @@ -916,7 +915,7 @@ def __xor__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __xor__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__") if other is NotImplemented: return other @@ -928,7 +927,7 @@ def __iadd__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __iadd__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__iadd__") if other is NotImplemented: return other @@ -939,7 +938,7 @@ def __radd__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __radd__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__radd__") if other is NotImplemented: return other @@ -951,7 +950,7 @@ def __iand__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __iand__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__") if other is NotImplemented: return other @@ -962,7 +961,7 @@ def __rand__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __rand__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__") if other is NotImplemented: return other @@ -974,7 +973,7 @@ def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ifloordiv__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__") if other is NotImplemented: return other @@ -985,7 +984,7 @@ def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __rfloordiv__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__") if other is NotImplemented: return other @@ -997,7 +996,7 @@ def __ilshift__(self: Array, other: Union[int, Array], /) -> Array: """ Performs the operation __ilshift__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "integer", "__ilshift__") if other is NotImplemented: return other @@ -1008,7 +1007,7 @@ def __rlshift__(self: Array, other: Union[int, Array], /) -> Array: """ Performs the operation __rlshift__. """ - other = self._check_device(other) + self._check_device(other) other = self._check_allowed_dtypes(other, "integer", "__rlshift__") if other is NotImplemented: return other @@ -1020,14 +1019,14 @@ def __imatmul__(self: Array, other: Array, /) -> Array: """ Performs the operation __imatmul__. """ - other = self._check_device(other) # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._check_allowed_dtypes(other, "numeric", "__imatmul__") if other is NotImplemented: return other + self._check_device(other) res = self._array.__imatmul__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __rmatmul__(self: Array, other: Array, /) -> Array: """ @@ -1038,8 +1037,9 @@ def __rmatmul__(self: Array, other: Array, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__") if other is NotImplemented: return other + self._check_device(other) res = self._array.__rmatmul__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __imod__(self: Array, other: Union[int, float, Array], /) -> Array: """ @@ -1058,9 +1058,10 @@ def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array: other = self._check_allowed_dtypes(other, "real numeric", "__rmod__") if other is NotImplemented: return other + self._check_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rmod__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __imul__(self: Array, other: Union[int, float, Array], /) -> Array: """ @@ -1079,9 +1080,10 @@ def __rmul__(self: Array, other: Union[int, float, Array], /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rmul__") if other is NotImplemented: return other + self._check_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rmul__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __ior__(self: Array, other: Union[int, bool, Array], /) -> Array: """ @@ -1097,12 +1099,13 @@ def __ror__(self: Array, other: Union[int, bool, Array], /) -> Array: """ Performs the operation __ror__. """ + self._check_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) res = self._array.__ror__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array: """ @@ -1144,9 +1147,10 @@ def __rrshift__(self: Array, other: Union[int, Array], /) -> Array: other = self._check_allowed_dtypes(other, "integer", "__rrshift__") if other is NotImplemented: return other + self._check_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rrshift__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __isub__(self: Array, other: Union[int, float, Array], /) -> Array: """ @@ -1165,9 +1169,10 @@ def __rsub__(self: Array, other: Union[int, float, Array], /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rsub__") if other is NotImplemented: return other + self._check_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rsub__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __itruediv__(self: Array, other: Union[float, Array], /) -> Array: """ @@ -1186,9 +1191,10 @@ def __rtruediv__(self: Array, other: Union[float, Array], /) -> Array: other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__") if other is NotImplemented: return other + self._check_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rtruediv__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def __ixor__(self: Array, other: Union[int, bool, Array], /) -> Array: """ @@ -1207,9 +1213,10 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array: other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__") if other is NotImplemented: return other + self._check_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rxor__(other._array) - return self.__class__._new(res) + return self.__class__._new(res, device=self.device) def to_device(self: Array, device: Device, /, stream: None = None) -> Array: if stream is not None: @@ -1279,4 +1286,4 @@ def T(self) -> Array: # https://data-apis.org/array-api/latest/API_specification/array_object.html#t if self.ndim != 2: raise ValueError("x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions.") - return self.__class__._new(self._array.T) + return self.__class__._new(self._array.T, device=self.device) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 18cbdfd..a46c7a8 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -79,7 +79,7 @@ def asarray( new_array = np.array(obj._array, copy=copy, dtype=_np_dtype) if new_array is not obj._array: raise ValueError("Unable to avoid copy while creating an array from given array.") - return Array._new(new_array) + return Array._new(new_array, device=device) elif _supports_buffer_protocol(obj): # Buffer protocol will always support no-copy return Array._new(np.array(obj, copy=copy, dtype=_np_dtype), device=device) @@ -211,7 +211,7 @@ def from_dlpack( if copy not in [_default, None]: raise NotImplementedError("The copy argument to from_dlpack is not yet implemented") - return Array._new(np.from_dlpack(x)) + return Array._new(np.from_dlpack(x), device=device) def full( diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 56d1a02..046dfc7 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -54,7 +54,7 @@ def broadcast_arrays(*arrays: Array) -> List[Array]: from ._array_object import Array return [ - Array._new(array) for array in np.broadcast_arrays(*[a._array for a in arrays]) + Array._new(array, device=arrays[0].device) for array in np.broadcast_arrays(*[a._array for a in arrays]) ] @@ -66,7 +66,7 @@ def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array: """ from ._array_object import Array - return Array._new(np.broadcast_to(x._array, shape)) + return Array._new(np.broadcast_to(x._array, shape), device=x.device) def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: diff --git a/array_api_strict/_linalg.py b/array_api_strict/_linalg.py index bd11aa4..d364997 100644 --- a/array_api_strict/_linalg.py +++ b/array_api_strict/_linalg.py @@ -1,5 +1,7 @@ from __future__ import annotations +from functools import partial + from ._dtypes import ( _floating_dtypes, _numeric_dtypes, @@ -59,11 +61,11 @@ def cholesky(x: Array, /, *, upper: bool = False) -> Array: raise TypeError('Only floating-point dtypes are allowed in cholesky') L = np.linalg.cholesky(x._array) if upper: - U = Array._new(L).mT + U = Array._new(L, device=x.device).mT if U.dtype in [complex64, complex128]: U = conj(U) return U - return Array._new(L) + return Array._new(L, device=x.device) # Note: cross is the numpy top-level namespace, not np.linalg @requires_extension('linalg') @@ -81,6 +83,9 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: if x1.shape[axis] != 3: raise ValueError('cross() dimension must equal 3') + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if get_array_api_strict_flags()['api_version'] >= '2023.12': if axis >= 0: raise ValueError("axis must be negative in cross") @@ -91,7 +96,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: # positive axis applied before or after broadcasting. NumPy applies # the axis before broadcasting. Since that behavior is what has always # been implemented here, we keep it for backwards compatibility. - return Array._new(np.cross(x1._array, x2._array, axis=axis)) + return Array._new(np.cross(x1._array, x2._array, axis=axis), device=x1.device) @requires_extension('linalg') def det(x: Array, /) -> Array: @@ -104,7 +109,7 @@ def det(x: Array, /) -> Array: # np.linalg.det. if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in det') - return Array._new(np.linalg.det(x._array)) + return Array._new(np.linalg.det(x._array), device=x.device) # Note: diagonal is the numpy top-level namespace, not np.linalg @requires_extension('linalg') @@ -116,7 +121,7 @@ def diagonal(x: Array, /, *, offset: int = 0) -> Array: """ # Note: diagonal always operates on the last two axes, whereas np.diagonal # operates on the first two axes by default - return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1)) + return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1), device=x.device) @requires_extension('linalg') def eigh(x: Array, /) -> EighResult: @@ -132,7 +137,7 @@ def eigh(x: Array, /) -> EighResult: # Note: the return type here is a namedtuple, which is different from # np.eigh, which only returns a tuple. - return EighResult(*map(Array._new, np.linalg.eigh(x._array))) + return EighResult(*map(partial(Array._new, device=x.device), np.linalg.eigh(x._array))) @requires_extension('linalg') @@ -147,7 +152,7 @@ def eigvalsh(x: Array, /) -> Array: if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in eigvalsh') - return Array._new(np.linalg.eigvalsh(x._array)) + return Array._new(np.linalg.eigvalsh(x._array), device=x.device) @requires_extension('linalg') def inv(x: Array, /) -> Array: @@ -161,7 +166,7 @@ def inv(x: Array, /) -> Array: if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in inv') - return Array._new(np.linalg.inv(x._array)) + return Array._new(np.linalg.inv(x._array), device=x.device) # Note: the name here is different from norm(). The array API norm is split # into matrix_norm and vector_norm(). @@ -181,7 +186,7 @@ def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in matrix_norm') - return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord)) + return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord), device=x.device) @requires_extension('linalg') @@ -197,7 +202,7 @@ def matrix_power(x: Array, n: int, /) -> Array: raise TypeError('Only floating-point dtypes are allowed for the first argument of matrix_power') # np.matrix_power already checks if n is an integer - return Array._new(np.linalg.matrix_power(x._array, n)) + return Array._new(np.linalg.matrix_power(x._array, n), device=x.device) # Note: the keyword argument name rtol is different from np.linalg.matrix_rank @requires_extension('linalg') @@ -220,7 +225,7 @@ def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> A # Note: this is different from np.linalg.matrix_rank, which does not multiply # the tolerance by the largest singular value. tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis] - return Array._new(np.count_nonzero(S > tol, axis=-1)) + return Array._new(np.count_nonzero(S > tol, axis=-1), device=x.device) # Note: outer is the numpy top-level namespace, not np.linalg @@ -240,7 +245,10 @@ def outer(x1: Array, x2: Array, /) -> Array: if x1.ndim != 1 or x2.ndim != 1: raise ValueError('The input arrays to outer must be 1-dimensional') - return Array._new(np.outer(x1._array, x2._array)) + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + + return Array._new(np.outer(x1._array, x2._array), device=x1.device) # Note: the keyword argument name rtol is different from np.linalg.pinv @requires_extension('linalg') @@ -259,7 +267,7 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: # default tolerance by max(M, N). if rtol is None: rtol = max(x.shape[-2:]) * finfo(x.dtype).eps - return Array._new(np.linalg.pinv(x._array, rcond=rtol)) + return Array._new(np.linalg.pinv(x._array, rcond=rtol), device=x.device) @requires_extension('linalg') def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: # noqa: F821 @@ -275,7 +283,7 @@ def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRRe # Note: the return type here is a namedtuple, which is different from # np.linalg.qr, which only returns a tuple. - return QRResult(*map(Array._new, np.linalg.qr(x._array, mode=mode))) + return QRResult(*map(partial(Array._new, device=x.device), np.linalg.qr(x._array, mode=mode))) @requires_extension('linalg') def slogdet(x: Array, /) -> SlogdetResult: @@ -291,7 +299,7 @@ def slogdet(x: Array, /) -> SlogdetResult: # Note: the return type here is a namedtuple, which is different from # np.linalg.slogdet, which only returns a tuple. - return SlogdetResult(*map(Array._new, np.linalg.slogdet(x._array))) + return SlogdetResult(*map(partial(Array._new, device=x.device), np.linalg.slogdet(x._array))) # Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a # vector when it is exactly 1-dimensional. All other cases treat x2 as a stack @@ -348,7 +356,10 @@ def solve(x1: Array, x2: Array, /) -> Array: if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in solve') - return Array._new(_solve(x1._array, x2._array)) + if x1.device != x2.device: + raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + + return Array._new(_solve(x1._array, x2._array), device=x1.device) @requires_extension('linalg') def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: @@ -364,7 +375,7 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: # Note: the return type here is a namedtuple, which is different from # np.svd, which only returns a tuple. - return SVDResult(*map(Array._new, np.linalg.svd(x._array, full_matrices=full_matrices))) + return SVDResult(*map(partial(Array._new, device=x.device), np.linalg.svd(x._array, full_matrices=full_matrices))) # Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to # np.linalg.svd(compute_uv=False). @@ -372,7 +383,7 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]: if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in svdvals') - return Array._new(np.linalg.svd(x._array, compute_uv=False)) + return Array._new(np.linalg.svd(x._array, compute_uv=False), device=x.device) # Note: trace is the numpy top-level namespace, not np.linalg @requires_extension('linalg') @@ -397,7 +408,7 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr dtype = dtype._np_dtype # Note: trace always operates on the last two axes, whereas np.trace # operates on the first two axes by default - return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype))) + return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype)), device=x.device) # Note: the name here is different from norm(). The array API norm is split # into matrix_norm and vector_norm(). @@ -437,7 +448,7 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No else: _axis = axis - res = Array._new(np.linalg.norm(a, axis=_axis, ord=ord)) + res = Array._new(np.linalg.norm(a, axis=_axis, ord=ord), device=x.device) if keepdims: # We can't reuse np.linalg.norm(keepdims) because of the reshape hacks diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 2b52889..2922c36 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -43,7 +43,7 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]: # Note: nonzero is disallowed on 0-dimensional arrays if x.ndim == 0: raise ValueError("nonzero is not allowed on 0-dimensional arrays") - return tuple(Array._new(i) for i in np.nonzero(x._array)) + return tuple(Array._new(i, device=x.device) for i in np.nonzero(x._array)) @requires_api_version('2023.12') def searchsorted( From 78def190da4b0b8ef09e484eb68b4f3237bdbba6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 16 Oct 2024 13:09:49 -0600 Subject: [PATCH 146/252] Add device check to repeat() --- array_api_strict/_manipulation_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index a430563..4d3f276 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -93,6 +93,8 @@ def repeat( raise RuntimeError("repeat() with repeats as an array requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") if repeats.dtype not in _integer_dtypes: raise TypeError("The repeats array must have an integer dtype") + if x.device != repeats.device: + raise RuntimeError(f"Arrays from two different devices ({x.device} and {repeats.device}) can not be combined.") elif isinstance(repeats, int): repeats = asarray(repeats) else: From 33450f30b1bb3aca9ab8aab5d525e9ed199b61a8 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 16 Oct 2024 13:14:48 -0600 Subject: [PATCH 147/252] Use ValueError for different device errors --- array_api_strict/_array_object.py | 2 +- array_api_strict/_elementwise_functions.py | 58 +++++++++---------- array_api_strict/_indexing_functions.py | 2 +- array_api_strict/_linalg.py | 6 +- array_api_strict/_linear_algebra_functions.py | 6 +- array_api_strict/_manipulation_functions.py | 2 +- array_api_strict/_searching_functions.py | 2 +- .../tests/test_indexing_functions.py | 6 +- 8 files changed, 42 insertions(+), 42 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 16d5d1a..cd9a360 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -221,7 +221,7 @@ def _check_device(self, other): return elif isinstance(other, Array): if self.device != other.device: - raise RuntimeError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar): diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 80a1b8f..761caff 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -62,7 +62,7 @@ def add(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in add") @@ -116,7 +116,7 @@ def atan2(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in atan2") # Call result type here just to raise on disallowed type combinations @@ -144,7 +144,7 @@ def bitwise_and(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if ( x1.dtype not in _integer_or_boolean_dtypes @@ -165,7 +165,7 @@ def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError("Only integer dtypes are allowed in bitwise_left_shift") @@ -197,7 +197,7 @@ def bitwise_or(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if ( x1.dtype not in _integer_or_boolean_dtypes @@ -218,7 +218,7 @@ def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: raise TypeError("Only integer dtypes are allowed in bitwise_right_shift") @@ -238,7 +238,7 @@ def bitwise_xor(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if ( x1.dtype not in _integer_or_boolean_dtypes @@ -280,9 +280,9 @@ def clip( See its docstring for more information. """ if isinstance(min, Array) and x.device != min.device: - raise RuntimeError(f"Arrays from two different devices ({x.device} and {min.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x.device} and {min.device}) can not be combined.") if isinstance(max, Array) and x.device != max.device: - raise RuntimeError(f"Arrays from two different devices ({x.device} and {max.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x.device} and {max.device}) can not be combined.") if (x.dtype not in _real_numeric_dtypes or isinstance(min, Array) and min.dtype not in _real_numeric_dtypes @@ -352,7 +352,7 @@ def copysign(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: raise TypeError("Only real numeric dtypes are allowed in copysign") @@ -390,7 +390,7 @@ def divide(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in divide") # Call result type here just to raise on disallowed type combinations @@ -406,7 +406,7 @@ def equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) @@ -456,7 +456,7 @@ def floor_divide(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in floor_divide") # Call result type here just to raise on disallowed type combinations @@ -472,7 +472,7 @@ def greater(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in greater") # Call result type here just to raise on disallowed type combinations @@ -488,7 +488,7 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in greater_equal") # Call result type here just to raise on disallowed type combinations @@ -504,7 +504,7 @@ def hypot(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in hypot") # Call result type here just to raise on disallowed type combinations @@ -563,7 +563,7 @@ def less(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in less") # Call result type here just to raise on disallowed type combinations @@ -579,7 +579,7 @@ def less_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in less_equal") # Call result type here just to raise on disallowed type combinations @@ -639,7 +639,7 @@ def logaddexp(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in logaddexp") # Call result type here just to raise on disallowed type combinations @@ -655,7 +655,7 @@ def logical_and(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_and") # Call result type here just to raise on disallowed type combinations @@ -682,7 +682,7 @@ def logical_or(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_or") # Call result type here just to raise on disallowed type combinations @@ -698,7 +698,7 @@ def logical_xor(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: raise TypeError("Only boolean dtypes are allowed in logical_xor") # Call result type here just to raise on disallowed type combinations @@ -714,7 +714,7 @@ def maximum(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in maximum") # Call result type here just to raise on disallowed type combinations @@ -732,7 +732,7 @@ def minimum(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in minimum") # Call result type here just to raise on disallowed type combinations @@ -747,7 +747,7 @@ def multiply(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in multiply") # Call result type here just to raise on disallowed type combinations @@ -774,7 +774,7 @@ def not_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) @@ -800,7 +800,7 @@ def pow(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in pow") # Call result type here just to raise on disallowed type combinations @@ -827,7 +827,7 @@ def remainder(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in remainder") # Call result type here just to raise on disallowed type combinations @@ -921,7 +921,7 @@ def subtract(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in subtract") # Call result type here just to raise on disallowed type combinations diff --git a/array_api_strict/_indexing_functions.py b/array_api_strict/_indexing_functions.py index 7506d4f..c0f8e26 100644 --- a/array_api_strict/_indexing_functions.py +++ b/array_api_strict/_indexing_functions.py @@ -23,5 +23,5 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: if indices.ndim != 1: raise ValueError("Only 1-dim indices array is supported") if x.device != indices.device: - raise RuntimeError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.") return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device) diff --git a/array_api_strict/_linalg.py b/array_api_strict/_linalg.py index d364997..d341277 100644 --- a/array_api_strict/_linalg.py +++ b/array_api_strict/_linalg.py @@ -84,7 +84,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: raise ValueError('cross() dimension must equal 3') if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") if get_array_api_strict_flags()['api_version'] >= '2023.12': if axis >= 0: @@ -246,7 +246,7 @@ def outer(x1: Array, x2: Array, /) -> Array: raise ValueError('The input arrays to outer must be 1-dimensional') if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") return Array._new(np.outer(x1._array, x2._array), device=x1.device) @@ -357,7 +357,7 @@ def solve(x1: Array, x2: Array, /) -> Array: raise TypeError('Only floating-point dtypes are allowed in solve') if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") return Array._new(_solve(x1._array, x2._array), device=x1.device) diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py index 6746bc7..5ffdaa6 100644 --- a/array_api_strict/_linear_algebra_functions.py +++ b/array_api_strict/_linear_algebra_functions.py @@ -31,7 +31,7 @@ def matmul(x1: Array, x2: Array, /) -> Array: raise TypeError('Only numeric dtypes are allowed in matmul') if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") return Array._new(np.matmul(x1._array, x2._array), device=x1.device) @@ -45,7 +45,7 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], raise TypeError('Only numeric dtypes are allowed in tensordot') if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") return Array._new(np.tensordot(x1._array, x2._array, axes=axes), device=x1.device) @@ -68,7 +68,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: raise ValueError("axis is out of bounds for x1 and x2") if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") # In versions of the standard prior to 2023.12, vecdot applied axis after # broadcasting. This is different from applying it before broadcasting diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index 4d3f276..d775835 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -94,7 +94,7 @@ def repeat( if repeats.dtype not in _integer_dtypes: raise TypeError("The repeats array must have an integer dtype") if x.device != repeats.device: - raise RuntimeError(f"Arrays from two different devices ({x.device} and {repeats.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x.device} and {repeats.device}) can not be combined.") elif isinstance(repeats, int): repeats = asarray(repeats) else: diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 2922c36..0d7c0c8 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -63,7 +63,7 @@ def searchsorted( raise TypeError("Only real numeric dtypes are allowed in searchsorted") if x1.device != x2.device: - raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") sorter = sorter._array if sorter is not None else None # TODO: The sort order of nans and signed zeros is implementation diff --git a/array_api_strict/tests/test_indexing_functions.py b/array_api_strict/tests/test_indexing_functions.py index 6d65239..f9fff58 100644 --- a/array_api_strict/tests/test_indexing_functions.py +++ b/array_api_strict/tests/test_indexing_functions.py @@ -31,14 +31,14 @@ def test_take_device(): x = xp.asarray([2, 3]) indices = xp.asarray([1, 1, 0], device=xp.Device("device1")) - with pytest.raises(RuntimeError, match="Arrays from two different devices"): + with pytest.raises(ValueError, match="Arrays from two different devices"): xp.take(x, indices) x = xp.asarray([2, 3], device=xp.Device("device1")) indices = xp.asarray([1, 1, 0]) - with pytest.raises(RuntimeError, match="Arrays from two different devices"): + with pytest.raises(ValueError, match="Arrays from two different devices"): xp.take(x, indices) x = xp.asarray([2, 3], device=xp.Device("device1")) indices = xp.asarray([1, 1, 0], device=xp.Device("device1")) - xp.take(x, indices) \ No newline at end of file + xp.take(x, indices) From c89d77a82e2ef7772c26f6cd10cd4b101ce3f90d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 18 Oct 2024 12:46:30 -0600 Subject: [PATCH 148/252] Add changelog entries for a 2.1 release --- docs/changelog.md | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 7b40fe3..1329d32 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,8 +1,32 @@ # Changelog -### 2.0.1 (2024-07-01) -## Minor Changes +## 2.1 (2024-10-18) + +## Major Changes + +- The default version of the array API standard is now 2023.12. 2022.12 can + still be enabled via the [flags API](array-api-strict-flags). + +- Added support for multiple fake "devices", so that code testing against + array-api-strict can check for proper device support. Currently there are + three "devices", the "CPU" device, which is the default devices, and two + pseudo "device" objects. This set of devices can be accessed with + `array_api_strict.__array_namespace_info__().devices()` (requires the array + API version to be set to 2023.12), and via the other array API APIs that + return devices (like `x.device`). These devices do not correspond to any + actual hardware and only exist for testing array API device semantics; for + instance, implicitly combining arrays on different devices results in an + exception. (Thanks to [@betatim](https://github.com/betatim)). + +### Minor Changes + +- Avoid implicitly relying on `__array__` in some places. These changes should + not be usef visible. + +## 2.0.1 (2024-07-01) + +### Minor Changes - Re-allow iteration on 1-D arrays. A change from 2.0 fixed iter() raising on n-D arrays but also made 1-D arrays raise. The standard does not explicitly From 349c4ff74686372c3873ff646b615e0776140bf7 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 21 Oct 2024 12:49:24 +0200 Subject: [PATCH 149/252] Propagate input array's device in asarray --- array_api_strict/_creation_functions.py | 6 ++++++ array_api_strict/tests/test_creation_functions.py | 13 ++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index a46c7a8..7924a85 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -67,6 +67,8 @@ def asarray( if dtype is not None: _np_dtype = dtype._np_dtype _check_device(device) + if isinstance(obj, Array) and device is None: + device = obj.device if np.__version__[0] < '2': if copy is False: @@ -158,6 +160,8 @@ def empty_like( _check_valid_dtype(dtype) _check_device(device) + if device is None: + device = x.device if dtype is not None: dtype = dtype._np_dtype @@ -260,6 +264,8 @@ def full_like( _check_valid_dtype(dtype) _check_device(device) + if device is None: + device = x.device if dtype is not None: dtype = dtype._np_dtype diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 819afad..71fd76b 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -23,7 +23,7 @@ zeros_like, ) from .._dtypes import float32, float64 -from .._array_object import Array, CPU_DEVICE +from .._array_object import Array, CPU_DEVICE, Device from .._flags import set_array_api_strict_flags def test_asarray_errors(): @@ -97,6 +97,17 @@ def test_asarray_copy(): a[0] = 0 assert all(b[0] == 0) + +def test_asarray_device_inference(): + assert asarray([1, 2, 3]).device == CPU_DEVICE + + x = asarray([1, 2, 3]) + assert asarray(x).device == CPU_DEVICE + + device1 = Device("device1") + x = asarray([1, 2, 3], device=device1) + assert asarray(x).device == device1 + def test_arange_errors(): arange(1, device=CPU_DEVICE) # Doesn't error assert_raises(ValueError, lambda: arange(1, device="cpu")) From f7152ff2ee1dbdc39fff856ff25f7825942719f4 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 21 Oct 2024 15:14:26 +0200 Subject: [PATCH 150/252] Use array's device when promoting scalars --- array_api_strict/_array_object.py | 2 +- array_api_strict/tests/test_array_object.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index cd9a360..34caff7 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -274,7 +274,7 @@ def _promote_scalar(self, scalar): # behavior for integers within the bounds of the integer dtype. # Outside of those bounds we use the default NumPy behavior (either # cast or raise OverflowError). - return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=CPU_DEVICE) + return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=self.device) @staticmethod def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 902c398..a9ea26d 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -6,7 +6,7 @@ import pytest from .. import ones, asarray, result_type, all, equal -from .._array_object import Array, CPU_DEVICE +from .._array_object import Array, CPU_DEVICE, Device from .._dtypes import ( _all_dtypes, _boolean_dtypes, @@ -88,6 +88,14 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[0]) assert_raises(IndexError, lambda: a[:]) +def test_promoted_scalar_inherits_device(): + device1 = Device("device1") + x = asarray([1., 2, 3], device=device1) + + y = x ** 2 + + assert y.device == device1 + def test_operators(): # For every operator, we test that it works for the required type # combinations and raises TypeError otherwise From a0161d0cabb78a17bec1d12c18efdd90f9ba59f6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 21 Oct 2024 14:36:11 -0600 Subject: [PATCH 151/252] Fix the definition of sign() for complex numbers This is correct in NumPy 2.0 but not in 1.x. --- array_api_strict/_elementwise_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 761caff..7dc6c5c 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -855,6 +855,8 @@ def sign(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sign") + if x.dtype in _complex_floating_dtypes: + return x/abs(x) return Array._new(np.sign(x._array), device=x.device) From 6c3b7d6c3baf8092ac5040d41a917f806a7aa30d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 28 Oct 2024 12:32:10 -0600 Subject: [PATCH 152/252] Temporarily enable __array__ in asarray so that parsing list of lists of Array can work --- array_api_strict/_array_object.py | 18 ++++++++++++++++++ array_api_strict/_creation_functions.py | 18 ++++++++++++++++-- .../tests/test_creation_functions.py | 15 ++++++++++++++- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index e7b53d9..f72b68e 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -53,6 +53,8 @@ def __repr__(self): _default = object() +_allow_array = False + class Array: """ n-d array object for the array API namespace. @@ -135,6 +137,22 @@ def __repr__(self: Array, /) -> str: # lead to code assuming np.asarray(other_array) would always work in the # standard. def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]: + # We have to allow this to be internally enabled as there's no other + # easy way to parse a list of Array objects in asarray(). + if _allow_array: + # copy keyword is new in 2.0.0; for older versions don't use it + # retry without that keyword. + if np.__version__[0] < '2': + return np.asarray(self._array, dtype=dtype) + elif np.__version__.startswith('2.0.0-dev0'): + # Handle dev version for which we can't know based on version + # number whether or not the copy keyword is supported. + try: + return np.asarray(self._array, dtype=dtype, copy=copy) + except TypeError: + return np.asarray(self._array, dtype=dtype) + else: + return np.asarray(self._array, dtype=dtype, copy=copy) raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported") # These are various helper functions to make the array behavior match the diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 67ba67c..52f9389 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -1,6 +1,6 @@ from __future__ import annotations - +from contextlib import contextmanager from typing import TYPE_CHECKING, List, Optional, Tuple, Union if TYPE_CHECKING: @@ -16,6 +16,19 @@ import numpy as np +@contextmanager +def allow_array(): + """ + Temporarily enable Array.__array__. This is needed for np.array to parse + list of lists of Array objects. + """ + from . import _array_object + original_value = _array_object._allow_array + try: + _array_object._allow_array = True + yield + finally: + _array_object._allow_array = original_value def _check_valid_dtype(dtype): # Note: Only spelling dtypes as the dtype objects is supported. @@ -94,7 +107,8 @@ def asarray( # Give a better error message in this case. NumPy would convert this # to an object array. TODO: This won't handle large integers in lists. raise OverflowError("Integer out of bounds for array dtypes") - res = np.array(obj, dtype=_np_dtype, copy=copy) + with allow_array(): + res = np.array(obj, dtype=_np_dtype, copy=copy) return Array._new(res) diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 819afad..bb486b1 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -22,7 +22,7 @@ zeros, zeros_like, ) -from .._dtypes import float32, float64 +from .._dtypes import int16, float32, float64 from .._array_object import Array, CPU_DEVICE from .._flags import set_array_api_strict_flags @@ -97,6 +97,19 @@ def test_asarray_copy(): a[0] = 0 assert all(b[0] == 0) +def test_asarray_list_of_lists(): + a = asarray(1, dtype=int16) + b = asarray([1], dtype=int16) + res = asarray([a, a]) + assert res.shape == (2,) + assert res.dtype == int16 + assert all(res == asarray([1, 1])) + + res = asarray([b, b]) + assert res.shape == (2, 1) + assert res.dtype == int16 + assert all(res == asarray([[1], [1]])) + def test_arange_errors(): arange(1, device=CPU_DEVICE) # Doesn't error assert_raises(ValueError, lambda: arange(1, device="cpu")) From bb2816745191892e2114bf3cb3c929ecee9c826f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 31 Oct 2024 13:51:48 -0600 Subject: [PATCH 153/252] Fix incorrect merge conflict resolution --- array_api_strict/_array_object.py | 4 ++-- array_api_strict/tests/test_array_object.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 8e12ef4..9416e38 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -160,8 +160,8 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None # We have to allow this to be internally enabled as there's no other # easy way to parse a list of Array objects in asarray(). if _allow_array: - if self._device != CPU_DEVICE: - raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") + if self._device != CPU_DEVICE: + raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") # copy keyword is new in 2.0.0; for older versions don't use it # retry without that keyword. if np.__version__[0] < '2': diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index c59aa54..c7781d7 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -353,14 +353,17 @@ def test_array_properties(): def test_array_conversion(): # Check that arrays on the CPU device can be converted to NumPy - # but arrays on other devices can't + # but arrays on other devices can't. Note this is testing the logic in + # __array__, which is only used in asarray when converting lists of + # arrays. a = ones((2, 3)) - np.asarray(a) + asarray([a]) for device in ("device1", "device2"): a = ones((2, 3), device=array_api_strict.Device(device)) with pytest.raises(RuntimeError, match="Can not convert array"): - np.asarray(a) + asarray([a]) + def test_allow_newaxis(): a = ones(5) indexed_a = a[None, :] From d630ee5f5a7c4474838c1782d694fe1733008d68 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 31 Oct 2024 14:57:10 -0600 Subject: [PATCH 154/252] Fix the pinv function, which was implicitly using __array__ --- array_api_strict/_linalg.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_strict/_linalg.py b/array_api_strict/_linalg.py index d341277..7d379a0 100644 --- a/array_api_strict/_linalg.py +++ b/array_api_strict/_linalg.py @@ -267,6 +267,8 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: # default tolerance by max(M, N). if rtol is None: rtol = max(x.shape[-2:]) * finfo(x.dtype).eps + if isinstance(rtol, Array): + rtol = rtol._array return Array._new(np.linalg.pinv(x._array, rcond=rtol), device=x.device) @requires_extension('linalg') From cfeb761e5a01c3c47e007ba85de1e8904360849f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 1 Nov 2024 15:43:11 -0600 Subject: [PATCH 155/252] Use a more robust implementation of clip This is based on the implementation from the compat library. Fixes #75 Closes #49 --- array_api_strict/_elementwise_functions.py | 54 ++++++++++++++++++---- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 7dc6c5c..8cec86a 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -14,6 +14,7 @@ from ._array_object import Array from ._flags import requires_api_version from ._creation_functions import asarray +from ._data_type_functions import broadcast_to, iinfo from typing import Optional, Union @@ -325,14 +326,51 @@ def clip( if min is not None and max is not None and np.any(min > max): raise ValueError("min must be less than or equal to max") - result = np.clip(x._array, min, max) - # Note: NumPy applies type promotion, but the standard specifies the - # return dtype should be the same as x - if result.dtype != x.dtype._np_dtype: - # TODO: I'm not completely sure this always gives the correct thing - # for integer dtypes. See https://github.com/numpy/numpy/issues/24976 - result = result.astype(x.dtype._np_dtype) - return Array._new(result, device=x.device) + # np.clip does type promotion but the array API clip requires that the + # output have the same dtype as x. We do this instead of just downcasting + # the result of xp.clip() to handle some corner cases better (e.g., + # avoiding uint64 -> float64 promotion). + + # Note: cases where min or max overflow (integer) or round (float) in the + # wrong direction when downcasting to x.dtype are unspecified. This code + # just does whatever NumPy does when it downcasts in the assignment, but + # other behavior could be preferred, especially for integers. For example, + # this code produces: + + # >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None) + # -128 + + # but an answer of 0 might be preferred. See + # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. + + # At least handle the case of Python integers correctly (see + # https://github.com/numpy/numpy/pull/26892). + if type(min) is int and min <= iinfo(x.dtype).min: + min = None + if type(max) is int and max >= iinfo(x.dtype).max: + max = None + + def _isscalar(a): + return isinstance(a, (int, float, type(None))) + min_shape = () if _isscalar(min) else min.shape + max_shape = () if _isscalar(max) else max.shape + + result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape) + + out = asarray(broadcast_to(x, result_shape), copy=True)._array + device = x.device + x = x._array + + if min is not None: + a = np.broadcast_to(np.asarray(min), result_shape) + ia = (out < a) | np.isnan(a) + + out[ia] = a[ia] + if max is not None: + b = np.broadcast_to(np.asarray(max), result_shape) + ib = (out > b) | np.isnan(b) + out[ib] = b[ib] + return Array._new(out, device=device) def conj(x: Array, /) -> Array: """ From 9530a7f97fa1370213ccccbb26e96bef3832d82d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 1 Nov 2024 18:02:10 -0600 Subject: [PATCH 156/252] Properly set the array API version on CI --- .github/workflows/array-api-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 36ef85c..9f168cb 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -49,6 +49,7 @@ jobs: run: | # Parameterizing this in the CI matrix is wasteful. Just do a loop here. for ARRAY_API_STRICT_API_VERSION in ${API_VERSIONS}; do + export ARRAY_API_STRICT_API_VERSION cd ${GITHUB_WORKSPACE}/array-api-tests pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/array-api-strict/array-api-tests-xfails.txt ${PYTEST_ARGS} done From 728c69ad274dfc263f8464bcc320af4e1a7895cb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 12:23:35 -0700 Subject: [PATCH 157/252] Support setting the version to the draft version of the standard --- array_api_strict/_flags.py | 11 ++++++++--- array_api_strict/tests/test_array_object.py | 6 +++++- array_api_strict/tests/test_flags.py | 15 +++++++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index c393ad9..b998f43 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -24,6 +24,8 @@ "2023.12", ) +draft_version = "2024.12" + API_VERSION = default_version = "2023.12" BOOLEAN_INDEXING = True @@ -70,8 +72,8 @@ def set_array_api_strict_flags( ---------- api_version : str, optional The version of the standard to use. Supported versions are: - ``{supported_versions}``. The default version number is - ``{default_version!r}``. + ``{supported_versions}``, plus the draft version ``{draft_version}``. + The default version number is ``{default_version!r}``. Note that 2021.12 is supported, but currently gives the same thing as 2022.12 (except that the fft extension will be disabled). @@ -134,10 +136,12 @@ def set_array_api_strict_flags( global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS if api_version is not None: - if api_version not in supported_versions: + if api_version not in [*supported_versions, draft_version]: raise ValueError(f"Unsupported standard version {api_version!r}") if api_version == "2021.12": warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2) + if api_version == draft_version: + warnings.warn(f"The {draft_version} version of the array API specification is in draft status. Not all features are implemented in array_api_strict, and behaviors are subject to change before the final standard release.") API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION @@ -169,6 +173,7 @@ def set_array_api_strict_flags( supported_versions=supported_versions, default_version=default_version, default_extensions=default_extensions, + draft_version=draft_version, ) def get_array_api_strict_flags(): diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index a9ea26d..5b8dbff 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -440,8 +440,12 @@ def test_array_namespace(): assert a.__array_namespace__(api_version="2021.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2021.12" + with pytest.warns(UserWarning): + assert a.__array_namespace__(api_version="2024.12") is array_api_strict + assert array_api_strict.__array_api_version__ == "2024.12" + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) - pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2025.12")) def test_iter(): pytest.raises(TypeError, lambda: iter(asarray(3))) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 2603f35..712e464 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -99,6 +99,21 @@ def test_flags_api_version_2023_12(): 'enabled_extensions': ('linalg', 'fft'), } +def test_flags_api_version_2024_12(): + # Make sure setting the version to 2024.12 issues a warning. + with pytest.warns(UserWarning) as record: + set_array_api_strict_flags(api_version='2024.12') + assert len(record) == 1 + assert '2024.12' in str(record[0].message) + assert 'draft' in str(record[0].message) + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2024.12', + 'boolean_indexing': True, + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } + def test_setting_flags_invalid(): # Test setting flags with invalid values pytest.raises(ValueError, lambda: From 31ceaae44189480ba8d404da7b3c958911d19552 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 12:47:08 -0700 Subject: [PATCH 158/252] Add preliminary diff() function for 2024.12 --- array_api_strict/__init__.py | 4 ++-- array_api_strict/_utility_functions.py | 23 +++++++++++++++++++++++ array_api_strict/tests/test_flags.py | 26 ++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 2 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index ff43660..025133c 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -305,9 +305,9 @@ __all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"] -from ._utility_functions import all, any +from ._utility_functions import all, any, diff -__all__ += ["all", "any"] +__all__ += ["all", "any", "diff"] from ._array_object import Device __all__ += ["Device"] diff --git a/array_api_strict/_utility_functions.py b/array_api_strict/_utility_functions.py index 0d44ecb..4cbea68 100644 --- a/array_api_strict/_utility_functions.py +++ b/array_api_strict/_utility_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._array_object import Array +from ._flags import requires_api_version from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -37,3 +38,25 @@ def any( See its docstring for more information. """ return Array._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims)), device=x.device) + +@requires_api_version('2024.12') +def diff( + x: Array, + /, + *, + axis: int = -1, + n: int = 1, + prepend: Optional[Array] = None, + append: Optional[Array] = None, +) -> Array: + # NumPy does not support prepend=None or append=None + kwargs = dict(axis=axis, n=n) + if prepend is not None: + if prepend.device != x.device: + raise ValueError(f"Arrays from two different devices ({prepend.device} and {x.device}) can not be combined.") + kwargs['prepend'] = prepend._array + if append is not None: + if append.device != x.device: + raise ValueError(f"Arrays from two different devices ({append.device} and {x.device}) can not be combined.") + kwargs['append'] = append._array + return Array._new(np.diff(x._array, **kwargs), device=x.device) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 712e464..7fa6828 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -282,6 +282,10 @@ def test_fft(func_name): 'unstack': lambda: xp.unstack(xp.ones((3, 3)), axis=0), } +api_version_2024_12_examples = { + 'diff': lambda: xp.diff(xp.asarray([0, 1, 2])), +} + @pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys()) def test_api_version_2023_12(func_name): func = api_version_2023_12_examples[func_name] @@ -300,6 +304,28 @@ def test_api_version_2023_12(func_name): set_array_api_strict_flags(api_version='2022.12') pytest.raises(RuntimeError, func) +@pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys()) +def test_api_version_2024_12(func_name): + func = api_version_2024_12_examples[func_name] + + # By default, these functions should error + pytest.raises(RuntimeError, func) + + # In 2022.12 and 2023.12, these functions should error + set_array_api_strict_flags(api_version='2022.12') + pytest.raises(RuntimeError, func) + set_array_api_strict_flags(api_version='2023.12') + pytest.raises(RuntimeError, func) + + # They should not error in 2024.12 + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2024.12') + func() + + # Test the behavior gets updated properly + set_array_api_strict_flags(api_version='2023.12') + pytest.raises(RuntimeError, func) + def test_disabled_extensions(): # Test that xp.extension errors when an extension is disabled, and that # xp.__all__ is updated properly. From 729175f85619e1102a2dd372bf6662ab09df3778 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 12:47:52 -0700 Subject: [PATCH 159/252] Add warning that functions may not be fully tested --- array_api_strict/_flags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index b998f43..2863e5f 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -141,7 +141,7 @@ def set_array_api_strict_flags( if api_version == "2021.12": warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2) if api_version == draft_version: - warnings.warn(f"The {draft_version} version of the array API specification is in draft status. Not all features are implemented in array_api_strict, and behaviors are subject to change before the final standard release.") + warnings.warn(f"The {draft_version} version of the array API specification is in draft status. Not all features are implemented in array_api_strict, some functions may not be fully tested, and behaviors are subject to change before the final standard release.") API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION From b2e3ecc4eaa62f496f7eeda2ca1a72c20bbb065c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 12:52:07 -0700 Subject: [PATCH 160/252] Require numeric types in diff --- array_api_strict/_utility_functions.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/array_api_strict/_utility_functions.py b/array_api_strict/_utility_functions.py index 4cbea68..f75f36f 100644 --- a/array_api_strict/_utility_functions.py +++ b/array_api_strict/_utility_functions.py @@ -2,6 +2,7 @@ from ._array_object import Array from ._flags import requires_api_version +from ._dtypes import _numeric_dtypes from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -49,6 +50,12 @@ def diff( prepend: Optional[Array] = None, append: Optional[Array] = None, ) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in diff") + + # TODO: The type promotion behavior for prepend and append is not + # currently specified. + # NumPy does not support prepend=None or append=None kwargs = dict(axis=axis, n=n) if prepend is not None: From 1d111b301164023f5f580be4139d1f8047f90d57 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 12:56:18 -0700 Subject: [PATCH 161/252] Add draft implementation for nextafter --- array_api_strict/__init__.py | 2 ++ array_api_strict/_elementwise_functions.py | 14 ++++++++++++++ .../tests/test_elementwise_functions.py | 9 +++++++-- array_api_strict/tests/test_flags.py | 1 + 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 025133c..8e6f9d7 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -172,6 +172,7 @@ minimum, multiply, negative, + nextafter, not_equal, positive, pow, @@ -240,6 +241,7 @@ "minimum", "multiply", "negative", + "nextafter", "not_equal", "positive", "pow", diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 8cec86a..8daab5f 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -805,6 +805,20 @@ def negative(x: Array, /) -> Array: return Array._new(np.negative(x._array), device=x.device) +@requires_api_version('2024.12') +def nextafter(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.nextafter `. + + See its docstring for more information. + """ + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in nextafter") + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np.nextafter(x1._array, x2._array), device=x1.device) + def not_equal(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.not_equal `. diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index de11edf..7aa51b6 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -2,6 +2,8 @@ from numpy.testing import assert_raises +import pytest + from .. import asarray, _elementwise_functions from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift from .._dtypes import ( @@ -79,6 +81,7 @@ def nargs(func): "minimum": "real numeric", "multiply": "numeric", "negative": "numeric", + "nextafter": "real floating-point", "not_equal": "all", "positive": "numeric", "pow": "numeric", @@ -126,7 +129,8 @@ def _array_vals(dtypes): yield asarray(1., dtype=d) # Use the latest version of the standard so all functions are included - set_array_api_strict_flags(api_version="2023.12") + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version="2024.12") for func_name, types in elementwise_function_input_types.items(): dtypes = _dtype_categories[types] @@ -162,7 +166,8 @@ def _array_vals(): yield asarray(1.0, dtype=d) # Use the latest version of the standard so all functions are included - set_array_api_strict_flags(api_version="2023.12") + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version="2024.12") for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 7fa6828..31d9ecd 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -284,6 +284,7 @@ def test_fft(func_name): api_version_2024_12_examples = { 'diff': lambda: xp.diff(xp.asarray([0, 1, 2])), + 'nextafter': lambda: xp.nextafter(xp.asarray(0.), xp.asarray(1.)), } @pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys()) From 69d0ed2779068366c12f39338306be0e303b44cc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 4 Nov 2024 21:30:44 +0000 Subject: [PATCH 162/252] Bump pypa/gh-action-pypi-publish in the actions group across 1 directory Bumps the actions group with 1 update in the / directory: [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish). Updates `pypa/gh-action-pypi-publish` from 1.9.0 to 1.11.0 - [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.9.0...v1.11.0) --- updated-dependencies: - dependency-name: pypa/gh-action-pypi-publish dependency-type: direct:production update-type: version-update:semver-minor dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/publish-package.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 5ebe50c..feb1095 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -97,7 +97,7 @@ jobs: if: >- (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') - uses: pypa/gh-action-pypi-publish@v1.9.0 + uses: pypa/gh-action-pypi-publish@v1.11.0 with: repository-url: https://test.pypi.org/legacy/ print-hash: true @@ -110,6 +110,6 @@ jobs: - name: Publish distribution 📦 to PyPI if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - uses: pypa/gh-action-pypi-publish@v1.9.0 + uses: pypa/gh-action-pypi-publish@v1.11.0 with: print-hash: true From d9f43f4fa2160fc67c0e91b177e19ae4580a5420 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 14:35:55 -0700 Subject: [PATCH 163/252] Add 'max dimensions' to capabilities() for 2024.12 --- array_api_strict/_info.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index 3ed7fb2..4927e97 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING +import numpy as np + if TYPE_CHECKING: from typing import Optional, Union, Tuple, List from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info @@ -18,9 +20,23 @@ def __array_namespace_info__() -> Info: @requires_api_version('2023.12') def capabilities() -> Capabilities: flags = get_array_api_strict_flags() - return {"boolean indexing": flags['boolean_indexing'], + res = {"boolean indexing": flags['boolean_indexing'], "data-dependent shapes": flags['data_dependent_shapes'], } + if flags['api_version'] >= '2024.12': + # maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will + # drop support for NumPy 1 but for now, just compute the number + # directly + for i in range(1, 100): + try: + np.zeros((1,)*i) + except ValueError: + maxdims = i - 1 + break + else: + raise RuntimeError("Could not get max dimensions (this is a bug in array-api-strict)") + res['max dimensions'] = maxdims + return res @requires_api_version('2023.12') def default_device() -> device: From 632e895af7b95c686e32ac9731b3f4bff7bd573c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 14:38:28 -0700 Subject: [PATCH 164/252] Add max dimensions to the Capabilities typing dict I don't know how to make this depend on API version so for now it's just there always. --- array_api_strict/_typing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index 05a479c..8fdfeda 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -54,7 +54,8 @@ class SupportsDLPack(Protocol): def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ... Capabilities = TypedDict( - "Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool} + "Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool, + "max dimensions": int} ) DefaultDataTypes = TypedDict( From 548f07174a26eef5b5b8501a4cba49412e90dc06 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 14:43:31 -0700 Subject: [PATCH 165/252] Make __array_namespace_info__ a class This makes it so that it doesn't have a bunch of extra names on it, which it did as a module. --- array_api_strict/_info.py | 247 +++++++++++++-------------- array_api_strict/_typing.py | 3 +- array_api_strict/tests/test_flags.py | 16 +- 3 files changed, 130 insertions(+), 136 deletions(-) diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index 4927e97..f288d2e 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -6,143 +6,134 @@ if TYPE_CHECKING: from typing import Optional, Union, Tuple, List - from ._typing import device, DefaultDataTypes, DataTypes, Capabilities, Info + from ._typing import device, DefaultDataTypes, DataTypes, Capabilities from ._array_object import ALL_DEVICES, CPU_DEVICE from ._flags import get_array_api_strict_flags, requires_api_version from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 @requires_api_version('2023.12') -def __array_namespace_info__() -> Info: - import array_api_strict._info - return array_api_strict._info - -@requires_api_version('2023.12') -def capabilities() -> Capabilities: - flags = get_array_api_strict_flags() - res = {"boolean indexing": flags['boolean_indexing'], - "data-dependent shapes": flags['data_dependent_shapes'], - } - if flags['api_version'] >= '2024.12': - # maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will - # drop support for NumPy 1 but for now, just compute the number - # directly - for i in range(1, 100): - try: - np.zeros((1,)*i) - except ValueError: - maxdims = i - 1 - break - else: - raise RuntimeError("Could not get max dimensions (this is a bug in array-api-strict)") - res['max dimensions'] = maxdims - return res - -@requires_api_version('2023.12') -def default_device() -> device: - return CPU_DEVICE +class __array_namespace_info__: + @requires_api_version('2023.12') + def capabilities(self) -> Capabilities: + flags = get_array_api_strict_flags() + res = {"boolean indexing": flags['boolean_indexing'], + "data-dependent shapes": flags['data_dependent_shapes'], + } + if flags['api_version'] >= '2024.12': + # maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will + # drop support for NumPy 1 but for now, just compute the number + # directly + for i in range(1, 100): + try: + np.zeros((1,)*i) + except ValueError: + maxdims = i - 1 + break + else: + raise RuntimeError("Could not get max dimensions (this is a bug in array-api-strict)") + res['max dimensions'] = maxdims + return res -@requires_api_version('2023.12') -def default_dtypes( - *, - device: Optional[device] = None, -) -> DefaultDataTypes: - return { - "real floating": float64, - "complex floating": complex128, - "integral": int64, - "indexing": int64, - } + @requires_api_version('2023.12') + def default_device(self) -> device: + return CPU_DEVICE -@requires_api_version('2023.12') -def dtypes( - *, - device: Optional[device] = None, - kind: Optional[Union[str, Tuple[str, ...]]] = None, -) -> DataTypes: - if kind is None: + @requires_api_version('2023.12') + def default_dtypes( + self, + *, + device: Optional[device] = None, + ) -> DefaultDataTypes: return { - "bool": bool, - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - "float32": float32, - "float64": float64, - "complex64": complex64, - "complex128": complex128, + "real floating": float64, + "complex floating": complex128, + "integral": int64, + "indexing": int64, } - if kind == "bool": - return {"bool": bool} - if kind == "signed integer": - return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - } - if kind == "unsigned integer": - return { - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - } - if kind == "integral": - return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - } - if kind == "real floating": - return { - "float32": float32, - "float64": float64, - } - if kind == "complex floating": - return { - "complex64": complex64, - "complex128": complex128, - } - if kind == "numeric": - return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - "float32": float32, - "float64": float64, - "complex64": complex64, - "complex128": complex128, - } - if isinstance(kind, tuple): - res = {} - for k in kind: - res.update(dtypes(kind=k)) - return res - raise ValueError(f"unsupported kind: {kind!r}") -@requires_api_version('2023.12') -def devices() -> List[device]: - return list(ALL_DEVICES) + @requires_api_version('2023.12') + def dtypes( + self, + *, + device: Optional[device] = None, + kind: Optional[Union[str, Tuple[str, ...]]] = None, + ) -> DataTypes: + if kind is None: + return { + "bool": bool, + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + } + if kind == "unsigned integer": + return { + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "integral": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "real floating": + return { + "float32": float32, + "float64": float64, + } + if kind == "complex floating": + return { + "complex64": complex64, + "complex128": complex128, + } + if kind == "numeric": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") -__all__ = [ - "capabilities", - "default_device", - "default_dtypes", - "devices", - "dtypes", -] + @requires_api_version('2023.12') + def devices(self) -> List[device]: + return list(ALL_DEVICES) diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index 8fdfeda..f13fdcf 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -29,6 +29,7 @@ from ._array_object import Array, _device from ._dtypes import _DType +from ._info import __array_namespace_info__ _T_co = TypeVar("_T_co", covariant=True) @@ -41,7 +42,7 @@ def __len__(self, /) -> int: ... Dtype = _DType -Info = ModuleType +Info = __array_namespace_info__ if sys.version_info >= (3, 12): from collections.abc import Buffer as SupportsBufferProtocol diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 31d9ecd..b6e544e 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -3,8 +3,7 @@ from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags, reset_array_api_strict_flags) -from .._info import (capabilities, default_device, default_dtypes, devices, - dtypes) +from .._info import __array_namespace_info__ from .._fft import (fft, ifft, fftn, ifftn, rfft, irfft, rfftn, irfftn, hfft, ihfft, fftfreq, rfftfreq, fftshift, ifftshift) from .._linalg import (cholesky, cross, det, diagonal, eigh, eigvalsh, inv, @@ -260,14 +259,17 @@ def test_fft(func_name): set_array_api_strict_flags(enabled_extensions=('fft',)) func() +# Test functionality even if the info object is already created +_info = xp.__array_namespace_info__() + api_version_2023_12_examples = { '__array_namespace_info__': lambda: xp.__array_namespace_info__(), # Test these functions directly to ensure they are properly decorated - 'capabilities': capabilities, - 'default_device': default_device, - 'default_dtypes': default_dtypes, - 'devices': devices, - 'dtypes': dtypes, + 'capabilities': _info.capabilities, + 'default_device': _info.default_device, + 'default_dtypes': _info.default_dtypes, + 'devices': _info.devices, + 'dtypes': _info.dtypes, 'clip': lambda: xp.clip(xp.asarray([1, 2, 3]), 1, 2), 'copysign': lambda: xp.copysign(xp.asarray([1., 2., 3.]), xp.asarray([-1., -1., -1.])), 'cumulative_sum': lambda: xp.cumulative_sum(xp.asarray([1, 2, 3])), From 9726bc096abc04c63aa804634025ad24d0aeba75 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 14:48:10 -0700 Subject: [PATCH 166/252] Add a draft reciprocal function for 2024.12 --- array_api_strict/__init__.py | 2 ++ array_api_strict/_elementwise_functions.py | 11 +++++++++++ array_api_strict/tests/test_elementwise_functions.py | 1 + array_api_strict/tests/test_flags.py | 1 + 4 files changed, 15 insertions(+) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 8e6f9d7..c8c2fa6 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -177,6 +177,7 @@ positive, pow, real, + reciprocal, remainder, round, sign, @@ -246,6 +247,7 @@ "positive", "pow", "real", + "reciprocal", "remainder", "round", "sign", diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 8daab5f..7c64f67 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -872,6 +872,17 @@ def real(x: Array, /) -> Array: return Array._new(np.real(x._array), device=x.device) +@requires_api_version('2024.12') +def reciprocal(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.reciprocal `. + + See its docstring for more information. + """ + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in reciprocal") + return Array._new(np.reciprocal(x._array), device=x.device) + def remainder(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.remainder `. diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 7aa51b6..4e1b9cc 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -86,6 +86,7 @@ def nargs(func): "positive": "numeric", "pow": "numeric", "real": "complex floating-point", + "reciprocal": "floating-point", "remainder": "real numeric", "round": "numeric", "sign": "numeric", diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index b6e544e..43139d1 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -287,6 +287,7 @@ def test_fft(func_name): api_version_2024_12_examples = { 'diff': lambda: xp.diff(xp.asarray([0, 1, 2])), 'nextafter': lambda: xp.nextafter(xp.asarray(0.), xp.asarray(1.)), + 'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])), } @pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys()) From 43d60b520da6235933ef568b4c5e27ccc1dedd59 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 4 Nov 2024 14:56:43 -0700 Subject: [PATCH 167/252] Add draft implementation of take_along_axis for 2024.12 As far as I can tell, NumPy matches the standard specification, except for the fact that NumPy does not set a default value for axis. --- array_api_strict/__init__.py | 4 ++-- array_api_strict/_indexing_functions.py | 12 ++++++++++++ array_api_strict/tests/test_flags.py | 14 ++++++++------ 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index c8c2fa6..98b0e95 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -262,9 +262,9 @@ "trunc", ] -from ._indexing_functions import take +from ._indexing_functions import take, take_along_axis -__all__ += ["take"] +__all__ += ["take", "take_along_axis"] from ._info import __array_namespace_info__ diff --git a/array_api_strict/_indexing_functions.py b/array_api_strict/_indexing_functions.py index c0f8e26..d7a400e 100644 --- a/array_api_strict/_indexing_functions.py +++ b/array_api_strict/_indexing_functions.py @@ -2,6 +2,7 @@ from ._array_object import Array from ._dtypes import _integer_dtypes +from ._flags import requires_api_version from typing import TYPE_CHECKING @@ -25,3 +26,14 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: if x.device != indices.device: raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.") return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device) + +@requires_api_version('2024.12') +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: + """ + Array API compatible wrapper for :py:func:`np.take_along_axis `. + + See its docstring for more information. + """ + if x.device != indices.device: + raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.") + return Array._new(np.take_along_axis(x._array, indices._array, axis), device=x.device) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 43139d1..a69a1ed 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -284,12 +284,6 @@ def test_fft(func_name): 'unstack': lambda: xp.unstack(xp.ones((3, 3)), axis=0), } -api_version_2024_12_examples = { - 'diff': lambda: xp.diff(xp.asarray([0, 1, 2])), - 'nextafter': lambda: xp.nextafter(xp.asarray(0.), xp.asarray(1.)), - 'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])), -} - @pytest.mark.parametrize('func_name', api_version_2023_12_examples.keys()) def test_api_version_2023_12(func_name): func = api_version_2023_12_examples[func_name] @@ -308,6 +302,14 @@ def test_api_version_2023_12(func_name): set_array_api_strict_flags(api_version='2022.12') pytest.raises(RuntimeError, func) +api_version_2024_12_examples = { + 'diff': lambda: xp.diff(xp.asarray([0, 1, 2])), + 'nextafter': lambda: xp.nextafter(xp.asarray(0.), xp.asarray(1.)), + 'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])), + 'take_along_axis': lambda: xp.take_along_axis(xp.zeros((2, 3)), + xp.zeros((1, 4), dtype=xp.int64)), +} + @pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys()) def test_api_version_2024_12(func_name): func = api_version_2024_12_examples[func_name] From 61b3c90c8885e1fa6a8ec807022ec2ba357e2e72 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 5 Nov 2024 14:22:40 -0700 Subject: [PATCH 168/252] Fix ruff issues --- array_api_strict/_info.py | 2 +- array_api_strict/_typing.py | 1 - array_api_strict/tests/test_flags.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index f288d2e..a9dbebf 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -130,7 +130,7 @@ def dtypes( if isinstance(kind, tuple): res = {} for k in kind: - res.update(dtypes(kind=k)) + res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index f13fdcf..94c4975 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -21,7 +21,6 @@ from typing import ( Any, - ModuleType, TypedDict, TypeVar, Protocol, diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index a69a1ed..e0b004b 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -3,7 +3,6 @@ from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags, reset_array_api_strict_flags) -from .._info import __array_namespace_info__ from .._fft import (fft, ifft, fftn, ifftn, rfft, irfft, rfftn, irfftn, hfft, ihfft, fftfreq, rfftfreq, fftshift, ifftshift) from .._linalg import (cholesky, cross, det, diagonal, eigh, eigvalsh, inv, From 894cfa00b363734b53f55b3537ad85614ba86917 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 6 Nov 2024 15:25:52 -0700 Subject: [PATCH 169/252] Add changelog entries for 2.1.1 --- docs/changelog.md | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/docs/changelog.md b/docs/changelog.md index 1329d32..d62653a 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,9 +1,22 @@ # Changelog +## 2.1.1 (2024-11-06) + +### Minor Changes + +- Use a more robust implementation of `clip()` that handles corner cases better. + +- Fix the definition of `sign()` for complex numbers when using NumPy 1.x. + +- Correctly use the array's device when promoting scalars. (Thanks to + [@betatim](https://github.com/betatim)) + +- Correctly propagate the input array's device in `asarray()`. (Thanks to + [@betatim](https://github.com/betatim)) ## 2.1 (2024-10-18) -## Major Changes +### Major Changes - The default version of the array API standard is now 2023.12. 2022.12 can still be enabled via the [flags API](array-api-strict-flags). From d9f7fa6891acb71a9d600acb66aba3e1a505c613 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 6 Nov 2024 15:34:13 -0700 Subject: [PATCH 170/252] Add a changelog entry for the removal of `__array__` --- docs/changelog.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index d62653a..9a451d3 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,14 @@ ## 2.1.1 (2024-11-06) +### Major Changes + +- Remove the `__array__` method from array-api-strict arrays. This means they + will no longer be implicitly converted to NumPy arrays when passed to `np` + functions. This method was previously implemented as a convenience, but it + isn't part of the array API standard. To portably convert an array API + strict array to a NumPy array, use `np.from_dlpack(x)` + ### Minor Changes - Use a more robust implementation of `clip()` that handles corner cases better. From 62ff67543b1204425a4881647f504d0ad88ec295 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 Nov 2024 23:43:35 +0000 Subject: [PATCH 171/252] Bump pypa/gh-action-pypi-publish in the actions group Bumps the actions group with 1 update: [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish). Updates `pypa/gh-action-pypi-publish` from 1.11.0 to 1.12.2 - [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.11.0...v1.12.2) --- updated-dependencies: - dependency-name: pypa/gh-action-pypi-publish dependency-type: direct:production update-type: version-update:semver-minor dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/publish-package.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index feb1095..b98aad0 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -97,7 +97,7 @@ jobs: if: >- (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') - uses: pypa/gh-action-pypi-publish@v1.11.0 + uses: pypa/gh-action-pypi-publish@v1.12.2 with: repository-url: https://test.pypi.org/legacy/ print-hash: true @@ -110,6 +110,6 @@ jobs: - name: Publish distribution 📦 to PyPI if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - uses: pypa/gh-action-pypi-publish@v1.11.0 + uses: pypa/gh-action-pypi-publish@v1.12.2 with: print-hash: true From 00536a97d0e91b5e11424626c7c58afbfd22fa67 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 6 Nov 2024 16:43:48 -0700 Subject: [PATCH 172/252] Create the GitHub release after publishing to PyPI --- .github/workflows/publish-package.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index feb1095..fbdb8ca 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -102,14 +102,14 @@ jobs: repository-url: https://test.pypi.org/legacy/ print-hash: true - - name: Create GitHub Release from a Tag - uses: softprops/action-gh-release@v2 - if: startsWith(github.ref, 'refs/tags/') - with: - files: dist/* - - name: Publish distribution 📦 to PyPI if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') uses: pypa/gh-action-pypi-publish@v1.11.0 with: print-hash: true + + - name: Create GitHub Release from a Tag + uses: softprops/action-gh-release@v2 + if: startsWith(github.ref, 'refs/tags/') + with: + files: dist/* From 35088135861f6db02f8353475b1c1e61ec23d192 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 6 Nov 2024 16:53:26 -0700 Subject: [PATCH 173/252] Add changelog for 2.1.2 --- docs/changelog.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index 9a451d3..e7a7f6b 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,10 @@ # Changelog +## 2.1.2 (2024-11-06) + +2.1.2 is identical to 2.1.1 and exists to fix issues with the PyPI publishing +that occured with the 2.1.1 release. + ## 2.1.1 (2024-11-06) ### Major Changes From 2ba54cf2288f6b36523a67a6bea9a715d7441c5d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 13:17:03 -0700 Subject: [PATCH 174/252] Comment out TestPyPI publishing It's currently broken (https://github.com/pypa/gh-action-pypi-publish/issues/283) --- .github/workflows/publish-package.yml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index bddb802..825b0a0 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -91,16 +91,16 @@ jobs: - name: List all files run: ls -lh dist - - name: Publish distribution 📦 to Test PyPI - # Publish to TestPyPI on tag events of if manually triggered - # Compare to 'true' string as booleans get turned into strings in the console - if: >- - (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) - || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') - uses: pypa/gh-action-pypi-publish@v1.12.2 - with: - repository-url: https://test.pypi.org/legacy/ - print-hash: true + # - name: Publish distribution 📦 to Test PyPI + # # Publish to TestPyPI on tag events of if manually triggered + # # Compare to 'true' string as booleans get turned into strings in the console + # if: >- + # (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) + # || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') + # uses: pypa/gh-action-pypi-publish@v1.12.2 + # with: + # repository-url: https://test.pypi.org/legacy/ + # print-hash: true - name: Publish distribution 📦 to PyPI if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') From 6e59494cbfd3ed949cf2009f778fc4cfdfc3eb04 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 13:18:01 -0700 Subject: [PATCH 175/252] Remove the 2.1.2 changelog 2.1.1 was never released to PyPI so I should be able to just do that one. --- docs/changelog.md | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index e7a7f6b..66f61d9 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,11 +1,6 @@ # Changelog -## 2.1.2 (2024-11-06) - -2.1.2 is identical to 2.1.1 and exists to fix issues with the PyPI publishing -that occured with the 2.1.1 release. - -## 2.1.1 (2024-11-06) +## 2.1.1 (2024-11-07) ### Major Changes From 2d3ebd9b4ef1df51535605d54a9c47ae026b4e50 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 13:21:07 -0700 Subject: [PATCH 176/252] Revert "Comment out TestPyPI publishing" This reverts commit 2ba54cf2288f6b36523a67a6bea9a715d7441c5d. --- .github/workflows/publish-package.yml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 825b0a0..bddb802 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -91,16 +91,16 @@ jobs: - name: List all files run: ls -lh dist - # - name: Publish distribution 📦 to Test PyPI - # # Publish to TestPyPI on tag events of if manually triggered - # # Compare to 'true' string as booleans get turned into strings in the console - # if: >- - # (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) - # || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') - # uses: pypa/gh-action-pypi-publish@v1.12.2 - # with: - # repository-url: https://test.pypi.org/legacy/ - # print-hash: true + - name: Publish distribution 📦 to Test PyPI + # Publish to TestPyPI on tag events of if manually triggered + # Compare to 'true' string as booleans get turned into strings in the console + if: >- + (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) + || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') + uses: pypa/gh-action-pypi-publish@v1.12.2 + with: + repository-url: https://test.pypi.org/legacy/ + print-hash: true - name: Publish distribution 📦 to PyPI if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') From 4c80f6d8ccfcf2c9dbe5bdc2d3812ed1c9114525 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 16:10:18 -0700 Subject: [PATCH 177/252] Allow __dlpack__ to work with newer versions of NumPy --- array_api_strict/_array_object.py | 19 ++++++++++--------- array_api_strict/tests/test_array_object.py | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 9416e38..fa9cce8 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -586,15 +586,16 @@ def __dlpack__( if copy is not _default: raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API") - # Going to wait for upstream numpy support - if max_version not in [_default, None]: - raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") - if dl_device not in [_default, None]: - raise NotImplementedError("The device argument to __dlpack__ is not yet implemented") - if copy not in [_default, None]: - raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") - - return self._array.__dlpack__(stream=stream) + if np.__version__ < '2.1': + if max_version not in [_default, None]: + raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") + if dl_device not in [_default, None]: + raise NotImplementedError("The device argument to __dlpack__ is not yet implemented") + if copy not in [_default, None]: + raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") + + return self._array.__dlpack__(stream=stream) + return self._array.__dlpack__(stream=stream, max_version=max_version, dl_device=dl_device, copy=copy) def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: """ diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index c7781d7..aea24da 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -460,7 +460,7 @@ def dlpack_2023_12(api_version): a.__dlpack__() - exception = NotImplementedError if api_version >= '2023.12' else ValueError + exception = NotImplementedError if api_version >= '2023.12' and np.__version__ < '2.1' else ValueError pytest.raises(exception, lambda: a.__dlpack__(dl_device=CPU_DEVICE)) pytest.raises(exception, lambda: From fa71e9e5f0b81f9413da4f9581908e67a2971b07 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 16:14:38 -0700 Subject: [PATCH 178/252] Fix version check --- array_api_strict/_array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index fa9cce8..03993ab 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -586,7 +586,7 @@ def __dlpack__( if copy is not _default: raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API") - if np.__version__ < '2.1': + if np.__version__[0] < '2.1': if max_version not in [_default, None]: raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") if dl_device not in [_default, None]: From 67d9667ba1ae883c55235043e4e13a80c23ff4e0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 16:14:43 -0700 Subject: [PATCH 179/252] Remove unused import --- array_api_strict/_array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 03993ab..faea86c 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -38,7 +38,7 @@ if TYPE_CHECKING: from typing import Optional, Tuple, Union, Any - from ._typing import PyCapsule, Device, Dtype + from ._typing import PyCapsule, Dtype import numpy.typing as npt import numpy as np From 93201f134463a2fa56bebecc349dd67d4dc3d49f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 16:23:16 -0700 Subject: [PATCH 180/252] Fix passing of keyword arguments in __dlpack__ for NumPy 2.1 --- array_api_strict/_array_object.py | 10 +++++- array_api_strict/tests/test_array_object.py | 39 +++++++++++++-------- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index faea86c..53669d1 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -595,7 +595,15 @@ def __dlpack__( raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") return self._array.__dlpack__(stream=stream) - return self._array.__dlpack__(stream=stream, max_version=max_version, dl_device=dl_device, copy=copy) + else: + kwargs = {'stream': stream} + if max_version is not _default: + kwargs['max_version'] = max_version + if dl_device is not _default: + kwargs['dl_device'] = dl_device + if copy is not _default: + kwargs['copy'] = copy + return self._array.__dlpack__(**kwargs) def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: """ diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index aea24da..96fd31e 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -460,18 +460,27 @@ def dlpack_2023_12(api_version): a.__dlpack__() - exception = NotImplementedError if api_version >= '2023.12' and np.__version__ < '2.1' else ValueError - pytest.raises(exception, lambda: - a.__dlpack__(dl_device=CPU_DEVICE)) - pytest.raises(exception, lambda: - a.__dlpack__(dl_device=None)) - pytest.raises(exception, lambda: - a.__dlpack__(max_version=(1, 0))) - pytest.raises(exception, lambda: - a.__dlpack__(max_version=None)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=False)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=True)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=None)) + if np.__version__ < '2.1': + exception = NotImplementedError if api_version >= '2023.12' else ValueError + pytest.raises(exception, lambda: + a.__dlpack__(dl_device=CPU_DEVICE)) + pytest.raises(exception, lambda: + a.__dlpack__(dl_device=None)) + pytest.raises(exception, lambda: + a.__dlpack__(max_version=(1, 0))) + pytest.raises(exception, lambda: + a.__dlpack__(max_version=None)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=False)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=True)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=None)) + else: + a.__dlpack__(dl_device=CPU_DEVICE) + a.__dlpack__(dl_device=None) + a.__dlpack__(max_version=(1, 0)) + a.__dlpack__(max_version=None) + a.__dlpack__(copy=False) + a.__dlpack__(copy=True) + a.__dlpack__(copy=None) From 6b4223e3e03c1cb3df63ec41bb4c897a70eab15b Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 17:46:03 -0700 Subject: [PATCH 181/252] Require NumPy >= 2.1 Fixes #21 --- .github/workflows/array-api-tests.yml | 7 +--- .github/workflows/tests.yml | 7 +--- array_api_strict/__init__.py | 6 +++ array_api_strict/_array_object.py | 40 +++++--------------- array_api_strict/_creation_functions.py | 23 ----------- array_api_strict/tests/test_array_object.py | 42 ++++++++------------- requirements-dev.txt | 2 +- requirements.txt | 2 +- setup.py | 2 +- 9 files changed, 37 insertions(+), 94 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 9f168cb..cbfe5a2 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -12,10 +12,7 @@ jobs: strategy: matrix: python-version: ['3.9', '3.10', '3.11', '3.12'] - numpy-version: ['1.26', 'dev'] - exclude: - - python-version: '3.8' - numpy-version: 'dev' + numpy-version: ['2.1', 'dev'] steps: - name: Checkout array-api-strict @@ -38,7 +35,7 @@ jobs: if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy; else - python -m pip install 'numpy>=1.26,<2.0'; + python -m pip install 'numpy==${{ matrix.numpy-version }}'; fi python -m pip install ${GITHUB_WORKSPACE}/array-api-strict python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d8124d4..312f9cd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -6,10 +6,7 @@ jobs: strategy: matrix: python-version: ['3.9', '3.10', '3.11', '3.12'] - numpy-version: ['1.26', 'dev'] - exclude: - - python-version: '3.8' - numpy-version: 'dev' + numpy-version: ['2.1', 'dev'] fail-fast: true steps: - uses: actions/checkout@v4 @@ -22,7 +19,7 @@ jobs: if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy; else - python -m pip install 'numpy>=1.26,<2.0'; + python -m pip install 'numpy==${{ matrix.numpy-version }}'; fi python -m pip install -r requirements-dev.txt - name: Run Tests diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index ff43660..cbda499 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -16,6 +16,12 @@ """ +import numpy as np +from numpy.lib import NumpyVersion + +if NumpyVersion(np.__version__) < NumpyVersion('2.1.0'): + raise ImportError("array-api-strict requires NumPy >= 2.1.0") + __all__ = [] # Warning: __array_api_version__ could change globally with diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 53669d1..76cdfac 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -162,19 +162,7 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None if _allow_array: if self._device != CPU_DEVICE: raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") - # copy keyword is new in 2.0.0; for older versions don't use it - # retry without that keyword. - if np.__version__[0] < '2': - return np.asarray(self._array, dtype=dtype) - elif np.__version__.startswith('2.0.0-dev0'): - # Handle dev version for which we can't know based on version - # number whether or not the copy keyword is supported. - try: - return np.asarray(self._array, dtype=dtype, copy=copy) - except TypeError: - return np.asarray(self._array, dtype=dtype) - else: - return np.asarray(self._array, dtype=dtype, copy=copy) + return np.asarray(self._array, dtype=dtype, copy=copy) raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported") # These are various helper functions to make the array behavior match the @@ -586,24 +574,14 @@ def __dlpack__( if copy is not _default: raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API") - if np.__version__[0] < '2.1': - if max_version not in [_default, None]: - raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") - if dl_device not in [_default, None]: - raise NotImplementedError("The device argument to __dlpack__ is not yet implemented") - if copy not in [_default, None]: - raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") - - return self._array.__dlpack__(stream=stream) - else: - kwargs = {'stream': stream} - if max_version is not _default: - kwargs['max_version'] = max_version - if dl_device is not _default: - kwargs['dl_device'] = dl_device - if copy is not _default: - kwargs['copy'] = copy - return self._array.__dlpack__(**kwargs) + kwargs = {'stream': stream} + if max_version is not _default: + kwargs['max_version'] = max_version + if dl_device is not _default: + kwargs['dl_device'] = dl_device + if copy is not _default: + kwargs['copy'] = copy + return self._array.__dlpack__(**kwargs) def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: """ diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index d6d3efa..8d7705b 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -83,29 +83,6 @@ def asarray( if isinstance(obj, Array) and device is None: device = obj.device - if np.__version__[0] < '2': - if copy is False: - # Note: copy=False is not yet implemented in np.asarray for - # NumPy 1 - - # Work around it by creating the new array and seeing if NumPy - # copies it. - if isinstance(obj, Array): - new_array = np.array(obj._array, copy=copy, dtype=_np_dtype) - if new_array is not obj._array: - raise ValueError("Unable to avoid copy while creating an array from given array.") - return Array._new(new_array, device=device) - elif _supports_buffer_protocol(obj): - # Buffer protocol will always support no-copy - return Array._new(np.array(obj, copy=copy, dtype=_np_dtype), device=device) - else: - # No-copy is unsupported for Python built-in types. - raise ValueError("Unable to avoid copy while creating an array from given object.") - - if copy is None: - # NumPy 1 treats copy=False the same as the standard copy=None - copy = False - if isinstance(obj, Array): return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype), device=device) if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)): diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 96fd31e..0480f00 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -456,31 +456,19 @@ def dlpack_2023_12(api_version): set_array_api_strict_flags(api_version=api_version) a = asarray([1, 2, 3], dtype=int8) - # Never an error - a.__dlpack__() - - if np.__version__ < '2.1': - exception = NotImplementedError if api_version >= '2023.12' else ValueError - pytest.raises(exception, lambda: - a.__dlpack__(dl_device=CPU_DEVICE)) - pytest.raises(exception, lambda: - a.__dlpack__(dl_device=None)) - pytest.raises(exception, lambda: - a.__dlpack__(max_version=(1, 0))) - pytest.raises(exception, lambda: - a.__dlpack__(max_version=None)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=False)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=True)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=None)) - else: - a.__dlpack__(dl_device=CPU_DEVICE) - a.__dlpack__(dl_device=None) - a.__dlpack__(max_version=(1, 0)) - a.__dlpack__(max_version=None) - a.__dlpack__(copy=False) - a.__dlpack__(copy=True) - a.__dlpack__(copy=None) + # Do not error + a.__dlpack__() + a.__dlpack__(dl_device=CPU_DEVICE) + a.__dlpack__(dl_device=None) + a.__dlpack__(max_version=(1, 0)) + a.__dlpack__(max_version=None) + a.__dlpack__(copy=False) + a.__dlpack__(copy=True) + a.__dlpack__(copy=None) + + x = np.from_dlpack(a) + assert isinstance(x, np.ndarray) + assert x.dtype == np.int8 + assert x.shape == (3,) + assert np.all(x == np.asarray([1, 2, 3])) diff --git a/requirements-dev.txt b/requirements-dev.txt index 137e973..62673a8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,3 @@ pytest hypothesis -numpy +numpy>=2.1 diff --git a/requirements.txt b/requirements.txt index 24ce15a..92ce2ba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -numpy +numpy>=2.1 diff --git a/setup.py b/setup.py index 29a94df..21d7417 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ url="https://data-apis.org/array-api-strict/", license="MIT", python_requires=">=3.9", - install_requires=["numpy"], + install_requires=["numpy>=2.1"], classifiers=[ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.9", From 0129dc4541f0c1d6d462abff92995c805319533b Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 17:50:31 -0700 Subject: [PATCH 182/252] Fix installation of numpy 2.1 on CI --- .github/workflows/array-api-tests.yml | 2 +- .github/workflows/tests.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index cbfe5a2..452b156 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -35,7 +35,7 @@ jobs: if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy; else - python -m pip install 'numpy==${{ matrix.numpy-version }}'; + python -m pip install 'numpy>=${{ matrix.numpy-version }},<${{ matrix.numpy-version }}.99'; fi python -m pip install ${GITHUB_WORKSPACE}/array-api-strict python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 312f9cd..e88b5ac 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,7 +19,7 @@ jobs: if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy; else - python -m pip install 'numpy==${{ matrix.numpy-version }}'; + python -m pip install 'numpy>=${{ matrix.numpy-version }},<${{ matrix.numpy-version }}.99'; fi python -m pip install -r requirements-dev.txt - name: Run Tests From 54ce945f772ffbf773af59bd957cc7c920409bfd Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 17:56:18 -0700 Subject: [PATCH 183/252] NumPy 2.1 requires Python >= 3.10 --- .github/workflows/array-api-tests.yml | 2 +- .github/workflows/tests.yml | 2 +- setup.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 452b156..4386cde 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] + python-version: ['3.10', '3.11', '3.12'] numpy-version: ['2.1', 'dev'] steps: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e88b5ac..a03c045 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -5,7 +5,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] + python-version: ['3.10', '3.11', '3.12'] numpy-version: ['2.1', 'dev'] fail-fast: true steps: diff --git a/setup.py b/setup.py index 21d7417..89e85e5 100644 --- a/setup.py +++ b/setup.py @@ -15,14 +15,15 @@ long_description_content_type="text/markdown", url="https://data-apis.org/array-api-strict/", license="MIT", - python_requires=">=3.9", + python_requires=">=3.10", install_requires=["numpy>=2.1"], classifiers=[ "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Intended Audience :: Developers", "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", ], From 2f411ee4e3038f412c0478a6c78da9043b944980 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 18:03:00 -0700 Subject: [PATCH 184/252] Add changelog for 2.1.2 release --- docs/changelog.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index 66f61d9..fe72889 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,11 @@ # Changelog +## 2.1.2 (2024-11-07) + +## Major Changes + +- array-api-strict now requires NumPy >= 2.1 and Python >= 3.10 + ## 2.1.1 (2024-11-07) ### Major Changes From c9fe697bec8de8e402d8beb71bbeb96ca70c772e Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 7 Nov 2024 18:09:02 -0700 Subject: [PATCH 185/252] Revert "Revert "Comment out TestPyPI publishing"" This reverts commit 2d3ebd9b4ef1df51535605d54a9c47ae026b4e50. --- .github/workflows/publish-package.yml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index bddb802..825b0a0 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -91,16 +91,16 @@ jobs: - name: List all files run: ls -lh dist - - name: Publish distribution 📦 to Test PyPI - # Publish to TestPyPI on tag events of if manually triggered - # Compare to 'true' string as booleans get turned into strings in the console - if: >- - (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) - || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') - uses: pypa/gh-action-pypi-publish@v1.12.2 - with: - repository-url: https://test.pypi.org/legacy/ - print-hash: true + # - name: Publish distribution 📦 to Test PyPI + # # Publish to TestPyPI on tag events of if manually triggered + # # Compare to 'true' string as booleans get turned into strings in the console + # if: >- + # (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) + # || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') + # uses: pypa/gh-action-pypi-publish@v1.12.2 + # with: + # repository-url: https://test.pypi.org/legacy/ + # print-hash: true - name: Publish distribution 📦 to PyPI if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') From ca387b294197cab8b17e2dfb70d8d9a25fb05caf Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 8 Nov 2024 13:39:08 -0700 Subject: [PATCH 186/252] Revert "Require NumPy >= 2.1" --- .github/workflows/array-api-tests.yml | 9 +++-- .github/workflows/tests.yml | 9 +++-- array_api_strict/__init__.py | 6 --- array_api_strict/_array_object.py | 40 +++++++++++++++----- array_api_strict/_creation_functions.py | 23 +++++++++++ array_api_strict/tests/test_array_object.py | 42 +++++++++++++-------- requirements-dev.txt | 2 +- requirements.txt | 2 +- setup.py | 7 ++-- 9 files changed, 98 insertions(+), 42 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 4386cde..9f168cb 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -11,8 +11,11 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.10', '3.11', '3.12'] - numpy-version: ['2.1', 'dev'] + python-version: ['3.9', '3.10', '3.11', '3.12'] + numpy-version: ['1.26', 'dev'] + exclude: + - python-version: '3.8' + numpy-version: 'dev' steps: - name: Checkout array-api-strict @@ -35,7 +38,7 @@ jobs: if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy; else - python -m pip install 'numpy>=${{ matrix.numpy-version }},<${{ matrix.numpy-version }}.99'; + python -m pip install 'numpy>=1.26,<2.0'; fi python -m pip install ${GITHUB_WORKSPACE}/array-api-strict python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a03c045..d8124d4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -5,8 +5,11 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.10', '3.11', '3.12'] - numpy-version: ['2.1', 'dev'] + python-version: ['3.9', '3.10', '3.11', '3.12'] + numpy-version: ['1.26', 'dev'] + exclude: + - python-version: '3.8' + numpy-version: 'dev' fail-fast: true steps: - uses: actions/checkout@v4 @@ -19,7 +22,7 @@ jobs: if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy; else - python -m pip install 'numpy>=${{ matrix.numpy-version }},<${{ matrix.numpy-version }}.99'; + python -m pip install 'numpy>=1.26,<2.0'; fi python -m pip install -r requirements-dev.txt - name: Run Tests diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index cbda499..ff43660 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -16,12 +16,6 @@ """ -import numpy as np -from numpy.lib import NumpyVersion - -if NumpyVersion(np.__version__) < NumpyVersion('2.1.0'): - raise ImportError("array-api-strict requires NumPy >= 2.1.0") - __all__ = [] # Warning: __array_api_version__ could change globally with diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 76cdfac..53669d1 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -162,7 +162,19 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None if _allow_array: if self._device != CPU_DEVICE: raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.") - return np.asarray(self._array, dtype=dtype, copy=copy) + # copy keyword is new in 2.0.0; for older versions don't use it + # retry without that keyword. + if np.__version__[0] < '2': + return np.asarray(self._array, dtype=dtype) + elif np.__version__.startswith('2.0.0-dev0'): + # Handle dev version for which we can't know based on version + # number whether or not the copy keyword is supported. + try: + return np.asarray(self._array, dtype=dtype, copy=copy) + except TypeError: + return np.asarray(self._array, dtype=dtype) + else: + return np.asarray(self._array, dtype=dtype, copy=copy) raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported") # These are various helper functions to make the array behavior match the @@ -574,14 +586,24 @@ def __dlpack__( if copy is not _default: raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API") - kwargs = {'stream': stream} - if max_version is not _default: - kwargs['max_version'] = max_version - if dl_device is not _default: - kwargs['dl_device'] = dl_device - if copy is not _default: - kwargs['copy'] = copy - return self._array.__dlpack__(**kwargs) + if np.__version__[0] < '2.1': + if max_version not in [_default, None]: + raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") + if dl_device not in [_default, None]: + raise NotImplementedError("The device argument to __dlpack__ is not yet implemented") + if copy not in [_default, None]: + raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") + + return self._array.__dlpack__(stream=stream) + else: + kwargs = {'stream': stream} + if max_version is not _default: + kwargs['max_version'] = max_version + if dl_device is not _default: + kwargs['dl_device'] = dl_device + if copy is not _default: + kwargs['copy'] = copy + return self._array.__dlpack__(**kwargs) def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: """ diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 8d7705b..d6d3efa 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -83,6 +83,29 @@ def asarray( if isinstance(obj, Array) and device is None: device = obj.device + if np.__version__[0] < '2': + if copy is False: + # Note: copy=False is not yet implemented in np.asarray for + # NumPy 1 + + # Work around it by creating the new array and seeing if NumPy + # copies it. + if isinstance(obj, Array): + new_array = np.array(obj._array, copy=copy, dtype=_np_dtype) + if new_array is not obj._array: + raise ValueError("Unable to avoid copy while creating an array from given array.") + return Array._new(new_array, device=device) + elif _supports_buffer_protocol(obj): + # Buffer protocol will always support no-copy + return Array._new(np.array(obj, copy=copy, dtype=_np_dtype), device=device) + else: + # No-copy is unsupported for Python built-in types. + raise ValueError("Unable to avoid copy while creating an array from given object.") + + if copy is None: + # NumPy 1 treats copy=False the same as the standard copy=None + copy = False + if isinstance(obj, Array): return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype), device=device) if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)): diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 0480f00..96fd31e 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -456,19 +456,31 @@ def dlpack_2023_12(api_version): set_array_api_strict_flags(api_version=api_version) a = asarray([1, 2, 3], dtype=int8) - - # Do not error + # Never an error a.__dlpack__() - a.__dlpack__(dl_device=CPU_DEVICE) - a.__dlpack__(dl_device=None) - a.__dlpack__(max_version=(1, 0)) - a.__dlpack__(max_version=None) - a.__dlpack__(copy=False) - a.__dlpack__(copy=True) - a.__dlpack__(copy=None) - - x = np.from_dlpack(a) - assert isinstance(x, np.ndarray) - assert x.dtype == np.int8 - assert x.shape == (3,) - assert np.all(x == np.asarray([1, 2, 3])) + + + if np.__version__ < '2.1': + exception = NotImplementedError if api_version >= '2023.12' else ValueError + pytest.raises(exception, lambda: + a.__dlpack__(dl_device=CPU_DEVICE)) + pytest.raises(exception, lambda: + a.__dlpack__(dl_device=None)) + pytest.raises(exception, lambda: + a.__dlpack__(max_version=(1, 0))) + pytest.raises(exception, lambda: + a.__dlpack__(max_version=None)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=False)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=True)) + pytest.raises(exception, lambda: + a.__dlpack__(copy=None)) + else: + a.__dlpack__(dl_device=CPU_DEVICE) + a.__dlpack__(dl_device=None) + a.__dlpack__(max_version=(1, 0)) + a.__dlpack__(max_version=None) + a.__dlpack__(copy=False) + a.__dlpack__(copy=True) + a.__dlpack__(copy=None) diff --git a/requirements-dev.txt b/requirements-dev.txt index 62673a8..137e973 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,3 @@ pytest hypothesis -numpy>=2.1 +numpy diff --git a/requirements.txt b/requirements.txt index 92ce2ba..24ce15a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -numpy>=2.1 +numpy diff --git a/setup.py b/setup.py index 89e85e5..29a94df 100644 --- a/setup.py +++ b/setup.py @@ -15,15 +15,14 @@ long_description_content_type="text/markdown", url="https://data-apis.org/array-api-strict/", license="MIT", - python_requires=">=3.10", - install_requires=["numpy>=2.1"], + python_requires=">=3.9", + install_requires=["numpy"], classifiers=[ "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "Intended Audience :: Developers", "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", ], From 476f45ab366f5193fe745a5757778283c80544b3 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 8 Nov 2024 13:57:50 -0700 Subject: [PATCH 187/252] Fix dlpack test and use NumpyVersion --- array_api_strict/_array_object.py | 2 +- array_api_strict/tests/test_array_object.py | 42 +++++++++++++-------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 53669d1..c57d6ed 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -586,7 +586,7 @@ def __dlpack__( if copy is not _default: raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API") - if np.__version__[0] < '2.1': + if np.lib.NumpyVersion(np.__version__) < '2.1.0': if max_version not in [_default, None]: raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") if dl_device not in [_default, None]: diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 96fd31e..4f843ba 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -448,7 +448,7 @@ def test_iter(): pytest.raises(TypeError, lambda: iter(ones((3, 3)))) @pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) -def dlpack_2023_12(api_version): +def test_dlpack_2023_12(api_version): if api_version == '2021.12': with pytest.warns(UserWarning): set_array_api_strict_flags(api_version=api_version) @@ -459,25 +459,35 @@ def dlpack_2023_12(api_version): # Never an error a.__dlpack__() - - if np.__version__ < '2.1': - exception = NotImplementedError if api_version >= '2023.12' else ValueError - pytest.raises(exception, lambda: - a.__dlpack__(dl_device=CPU_DEVICE)) - pytest.raises(exception, lambda: + if api_version < '2023.12': + pytest.raises(ValueError, lambda: + a.__dlpack__(dl_device=a.__dlpack_device__())) + pytest.raises(ValueError, lambda: a.__dlpack__(dl_device=None)) - pytest.raises(exception, lambda: + pytest.raises(ValueError, lambda: a.__dlpack__(max_version=(1, 0))) - pytest.raises(exception, lambda: + pytest.raises(ValueError, lambda: a.__dlpack__(max_version=None)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=False)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=True)) - pytest.raises(exception, lambda: - a.__dlpack__(copy=None)) + pytest.raises(ValueError, lambda: + a.__dlpack__(copy=False)) + pytest.raises(ValueError, lambda: + a.__dlpack__(copy=True)) + pytest.raises(ValueError, lambda: + a.__dlpack__(copy=None)) + elif np.lib.NumpyVersion(np.__version__) < '2.1.0': + pytest.raises(NotImplementedError, lambda: + a.__dlpack__(dl_device=CPU_DEVICE)) + a.__dlpack__(dl_device=None) + pytest.raises(NotImplementedError, lambda: + a.__dlpack__(max_version=(1, 0))) + a.__dlpack__(max_version=None) + pytest.raises(NotImplementedError, lambda: + a.__dlpack__(copy=False)) + pytest.raises(NotImplementedError, lambda: + a.__dlpack__(copy=True)) + a.__dlpack__(copy=None) else: - a.__dlpack__(dl_device=CPU_DEVICE) + a.__dlpack__(dl_device=a.__dlpack_device__()) a.__dlpack__(dl_device=None) a.__dlpack__(max_version=(1, 0)) a.__dlpack__(max_version=None) From b20f8b491bbfbffa5084399295277345aa5d0e3b Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 8 Nov 2024 13:59:17 -0700 Subject: [PATCH 188/252] Use NumpyVersion in asarray --- array_api_strict/_creation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index d6d3efa..e506bcc 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -83,7 +83,7 @@ def asarray( if isinstance(obj, Array) and device is None: device = obj.device - if np.__version__[0] < '2': + if np.lib.NumpyVersion(np.__version__) < '2.0.0': if copy is False: # Note: copy=False is not yet implemented in np.asarray for # NumPy 1 From 0c9bb5e21855d9baeef7f10cfc3876bb1eb8de9c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 8 Nov 2024 15:25:55 -0700 Subject: [PATCH 189/252] Re-enable __array__ Removing it caused issues for SciPy (https://github.com/data-apis/array-api-strict/issues/67). I have left the flag in to make it easy to remove it in the future. I also considered raising a warning in __array__, but this is also difficult to handle https://github.com/data-apis/array-api-strict/pull/91 --- array_api_strict/_array_object.py | 21 +++++++++++++-------- array_api_strict/tests/test_array_object.py | 17 +++++++++++++++++ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index c57d6ed..0de6b8a 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -66,7 +66,9 @@ def __hash__(self): _default = object() -_allow_array = False +# See https://github.com/data-apis/array-api-strict/issues/67 and the comment +# on __array__ below. +_allow_array = True class Array: """ @@ -147,15 +149,18 @@ def __repr__(self: Array, /) -> str: mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix) return prefix + mid + suffix - # Disallow __array__, meaning calling `np.func()` on an array_api_strict - # array will give an error. If we don't explicitly disallow it, NumPy - # defaults to creating an object dtype array, which would lead to - # confusing error messages at best and surprising bugs at worst. - # - # The alternative of course is to just support __array__, which is what we - # used to do. But this isn't actually supported by the standard, so it can + # In the future, _allow_array will be set to False, which will disallow + # __array__. This means calling `np.func()` on an array_api_strict array + # will give an error. If we don't explicitly disallow it, NumPy defaults + # to creating an object dtype array, which would lead to confusing error + # messages at best and surprising bugs at worst. The reason for doing this + # is that __array__ is not actually supported by the standard, so it can # lead to code assuming np.asarray(other_array) would always work in the # standard. + # + # This was implemented historically for compatibility, and removing it has + # caused issues for some libraries (see + # https://github.com/data-apis/array-api-strict/issues/67). def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]: # We have to allow this to be internally enabled as there's no other # easy way to parse a list of Array objects in asarray(). diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 4f843ba..29b7d17 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -364,6 +364,23 @@ def test_array_conversion(): with pytest.raises(RuntimeError, match="Can not convert array"): asarray([a]) +def test__array__(): + # __array__ should work for now + a = ones((2, 3)) + np.array(a) + + # Test the _allow_array private global flag for disabling it in the + # future. + from .. import _array_object + original_value = _array_object._allow_array + try: + _array_object._allow_array = False + a = ones((2, 3)) + with pytest.raises(ValueError, match="Conversion from an array_api_strict array to a NumPy ndarray is not supported"): + np.array(a) + finally: + _array_object._allow_array = original_value + def test_allow_newaxis(): a = ones(5) indexed_a = a[None, :] From 1fe03a5f97e31b826dad2a5f125534b1a1efab9f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 8 Nov 2024 15:31:33 -0700 Subject: [PATCH 190/252] Add changelog for a 2.1.3 release Still need to merge https://github.com/data-apis/array-api-strict/pull/92 into this before doing the release. --- docs/changelog.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index fe72889..e048b1a 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,21 @@ # Changelog +## 2.1.3 (2024-11-08) + +## Major Changes + +- Revert the change to require NumPy >= 2.1 and Python >= 3.10 from + array-api-strict 2.1.2. array-api-strict now requires NumPy >= 1.21 and + Python >= 3.9, as before. These changes were made to improve the maintenance + of array-api-strict, but they caused some issues in upstream packages that + cannot yet support NumPy 2.0, so this will be postponed to a later date. + +- Revert the removal of `__array__` from array-api-strict 2.1.1. This caused + some difficulties for upstream libraries, so it will be postponed to a later + date. This is still planned because `__array__` is not part of the array API + standard. See https://github.com/data-apis/array-api-strict/issues/67 for + more discussion about this. + ## 2.1.2 (2024-11-07) ## Major Changes From fe4f34ad30ba59c1c981037dd181f6b4daf23b22 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 11 Nov 2024 15:30:59 -0700 Subject: [PATCH 191/252] Add a changelog for a 2.2 release --- docs/changelog.md | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index e048b1a..d33dc24 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,8 +1,39 @@ # Changelog +## 2.2 (2024-11-11) + +### Major Changes + +- Preliminary support for the draft 2024.12 version of the standard is now + implemented. This is disabled by default, but can be enabled with the [flags + API](array-api-strict-flags), e.g., by calling + `set_array_api_strict_flags(api_version='2024.12')` or setting + `ARRAY_API_STRICT_API_VERSION=2024.12`. + + Note that this support is still preliminary and still relatively untested. + Please [report any + issues](https://github.com/data-apis/array-api-strict/issues) you find. + + The following functions are implemented for 2024: + + - `diff` + - `nextafter` + - `reciprocal` + - `take_along_axis` + - The `'max dimensions'` key of `__array_namespace_info__().capabilities()`. + + Some planned changes to the 2024.12 standard, including scalar support for + array functions, is not yet implemented but will be in a future version. + +### Minor Changes + +- `__array_namespace_info__()` now returns a class instead of a module. This + prevents extraneous names that aren't part of the standard from appearing on + it. + ## 2.1.3 (2024-11-08) -## Major Changes +### Major Changes - Revert the change to require NumPy >= 2.1 and Python >= 3.10 from array-api-strict 2.1.2. array-api-strict now requires NumPy >= 1.21 and @@ -18,7 +49,7 @@ ## 2.1.2 (2024-11-07) -## Major Changes +### Major Changes - array-api-strict now requires NumPy >= 2.1 and Python >= 3.10 From dcf95bf0ea178e1b89e18b664225242b72be2a62 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 11 Nov 2024 15:51:10 -0700 Subject: [PATCH 192/252] Add releasing documentation --- docs/index.md | 1 + docs/releasing.md | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 docs/releasing.md diff --git a/docs/index.md b/docs/index.md index 12aadbb..65006b6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -197,4 +197,5 @@ git_filter_repo.py --path numpy/array_api/ --path-rename numpy/array_api:array_a api.rst changelog.md +releasing.md ``` diff --git a/docs/releasing.md b/docs/releasing.md new file mode 100644 index 0000000..db7dba7 --- /dev/null +++ b/docs/releasing.md @@ -0,0 +1,37 @@ +# Releasing + +To release array-api-strict: + +- Create a release branch and make a PR on GitHub. + +- Update `changelog.md` with the changes for the release. + +- Make sure the CI is passing on the release branch PR. Also double check that + you have properly pulled `main` and merged it into the release branch so + that the branch contains all the necessary changes for the release. + +- When you are ready to make the release, create a tag with the release number + + ``` + git tag -a 2.2 -m "array-api-strict 2.2" + ``` + + and push it up to GitHub + + ``` + git push origin --tags + ``` + + This will trigger the `publish-package` build on GitHub Actions. Make sure + that build works correctly and pushes the release up to PyPI. If something + goes wrong, you may need to delete the tag from GitHub and try again. + + Note that the `array_api_strict.__version__` version as well as the version + in the package metadata is all automatically computed from the tag, so it is + not necessary to update the version anywhere else. + +- Once the release is published, you can merge the PR. + +- The conda-forge bot will automatically send a PR to the + [array-api-strict-feedstock](https://github.com/conda-forge/array-api-strict-feedstock) + updating the version, which you should merge. From b7d5bad0c2e666a6fbc457a6356e6b7ce2b9b8a8 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 12 Nov 2024 16:03:14 -0700 Subject: [PATCH 193/252] Use checkboxes for the releasing docs --- docs/_static/custom.css | 14 ++++++++++++++ docs/conf.py | 3 ++- docs/releasing.md | 12 ++++++------ 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/docs/_static/custom.css b/docs/_static/custom.css index bac0498..c712f02 100644 --- a/docs/_static/custom.css +++ b/docs/_static/custom.css @@ -10,3 +10,17 @@ body { html { scroll-behavior: auto; } + +/* Make checkboxes from the tasklist extension ('- [ ]' in Markdown) not add bullet points to the checkboxes. + +This can be removed once https://github.com/executablebooks/mdit-py-plugins/issues/59 is addressed. +*/ + +.contains-task-list { + list-style: none; +} + +/* Make the checkboxes indented like they are bullets */ +.task-list-item-checkbox { + margin: 0 0.2em 0.25em -1.4em; +} diff --git a/docs/conf.py b/docs/conf.py index e4c66d7..14d7503 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -36,7 +36,8 @@ templates_path = ['_templates'] exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] -myst_enable_extensions = ["dollarmath", "linkify"] +myst_enable_extensions = ["dollarmath", "linkify", "tasklist"] +myst_enable_checkboxes = True napoleon_use_rtype = False napoleon_use_param = False diff --git a/docs/releasing.md b/docs/releasing.md index db7dba7..5515411 100644 --- a/docs/releasing.md +++ b/docs/releasing.md @@ -2,15 +2,15 @@ To release array-api-strict: -- Create a release branch and make a PR on GitHub. +- [ ] Create a release branch and make a PR on GitHub. -- Update `changelog.md` with the changes for the release. +- [ ] Update `changelog.md` with the changes for the release. -- Make sure the CI is passing on the release branch PR. Also double check that +- [ ] Make sure the CI is passing on the release branch PR. Also double check that you have properly pulled `main` and merged it into the release branch so that the branch contains all the necessary changes for the release. -- When you are ready to make the release, create a tag with the release number +- [ ] When you are ready to make the release, create a tag with the release number ``` git tag -a 2.2 -m "array-api-strict 2.2" @@ -30,8 +30,8 @@ To release array-api-strict: in the package metadata is all automatically computed from the tag, so it is not necessary to update the version anywhere else. -- Once the release is published, you can merge the PR. +- [ ] Once the release is published, you can merge the PR. -- The conda-forge bot will automatically send a PR to the +- [ ] The conda-forge bot will automatically send a PR to the [array-api-strict-feedstock](https://github.com/conda-forge/array-api-strict-feedstock) updating the version, which you should merge. From e2ae30b09f400f797130dcc464a742d4481f1caf Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 19 Nov 2024 15:48:23 -0700 Subject: [PATCH 194/252] Conjugate the first argument in vecdot This is currently untested by the test suite (data-apis/array-api-tests#312) Fixes #97. --- array_api_strict/_linear_algebra_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py index 5ffdaa6..6af2a15 100644 --- a/array_api_strict/_linear_algebra_functions.py +++ b/array_api_strict/_linear_algebra_functions.py @@ -86,5 +86,5 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1_ = np.moveaxis(x1_, axis, -1) x2_ = np.moveaxis(x2_, axis, -1) - res = x1_[..., None, :] @ x2_[..., None] + res = np.conj(x1_[..., None, :]) @ x2_[..., None] return Array._new(res[..., 0, 0], device=x1.device) From 84573542760b325ac96d3aa7bd696a4d2dd0f7cc Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 26 Nov 2024 16:09:39 +0200 Subject: [PATCH 195/252] ENH: prohibit astype(complex, not complex) --- array_api_strict/_data_type_functions.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 046dfc7..acd8967 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -42,6 +42,13 @@ def astype( if not copy and dtype == x.dtype: return x + + if isdtype(x.dtype, 'complex floating') and not isdtype(dtype, 'complex floating'): + raise TypeError( + f'The Array API standard stipulates that casting {x.dtype} to {dtype} should not be permitted. ' + 'array-api-strict thus prohibits this conversion.' + ) + return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy), device=device) From 6bf7ac259ab059248b68caf81d046794fc8890be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Fri, 29 Nov 2024 12:02:16 +0100 Subject: [PATCH 196/252] BUG: fix import error against Python's optimized mode --- array_api_strict/_flags.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 2863e5f..279b0e7 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -169,12 +169,13 @@ def set_array_api_strict_flags( set(default_extensions)) # We have to do this separately or it won't get added as the docstring -set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( - supported_versions=supported_versions, - default_version=default_version, - default_extensions=default_extensions, - draft_version=draft_version, -) +if set_array_api_strict_flags.__doc__ is not None: + set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( + supported_versions=supported_versions, + default_version=default_version, + default_extensions=default_extensions, + draft_version=draft_version, + ) def get_array_api_strict_flags(): """ From 80cdbeac0a8dea90aa6e4601752f7dddf37a92bf Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 4 Dec 2024 20:33:10 +0000 Subject: [PATCH 197/252] Bump dawidd6/action-download-artifact from 6 to 7 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 6 to 7 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v6...v7) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 9aa379d..4106b88 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download Artifact - uses: dawidd6/action-download-artifact@v6 + uses: dawidd6/action-download-artifact@v7 with: workflow: docs-build.yml name: docs-build From e68ea3c94a2f703a5e1bb2a01bedfbde218c49ac Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Dec 2024 21:21:07 +0000 Subject: [PATCH 198/252] Bump pypa/gh-action-pypi-publish in the actions group Bumps the actions group with 1 update: [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish). Updates `pypa/gh-action-pypi-publish` from 1.12.2 to 1.12.3 - [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.12.2...v1.12.3) --- updated-dependencies: - dependency-name: pypa/gh-action-pypi-publish dependency-type: direct:production update-type: version-update:semver-patch dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/publish-package.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 825b0a0..66c5cc6 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -97,14 +97,14 @@ jobs: # if: >- # (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) # || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') - # uses: pypa/gh-action-pypi-publish@v1.12.2 + # uses: pypa/gh-action-pypi-publish@v1.12.3 # with: # repository-url: https://test.pypi.org/legacy/ # print-hash: true - name: Publish distribution 📦 to PyPI if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - uses: pypa/gh-action-pypi-publish@v1.12.2 + uses: pypa/gh-action-pypi-publish@v1.12.3 with: print-hash: true From 202c46bb1afaa892a99afe8e96093811bae55f37 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Mon, 9 Dec 2024 17:55:52 +0000 Subject: [PATCH 199/252] bug: where: check `condition` is boolean --- array_api_strict/_searching_functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 0d7c0c8..5460b30 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -1,7 +1,7 @@ from __future__ import annotations from ._array_object import Array -from ._dtypes import _result_type, _real_numeric_dtypes +from ._dtypes import _result_type, _real_numeric_dtypes, bool as _bool from ._flags import requires_data_dependent_shapes, requires_api_version from typing import TYPE_CHECKING @@ -80,6 +80,9 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array: """ # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) + + if condition.dtype != _bool: + raise TypeError("`condition` must be have a boolean data type") if len({a.device for a in (condition, x1, x2)}) > 1: raise ValueError("where inputs must all be on the same device") From f8a6a9eb6c66799557d3af182500dbb4fbef1af7 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Tue, 10 Dec 2024 22:34:55 +0000 Subject: [PATCH 200/252] BUG: `from_dlpack`: fix default device --- array_api_strict/_creation_functions.py | 2 ++ array_api_strict/tests/test_creation_functions.py | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index e506bcc..460dba9 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -226,6 +226,8 @@ def from_dlpack( # Going to wait for upstream numpy support if device is not _default: _check_device(device) + else: + device = None if copy not in [_default, None]: raise NotImplementedError("The copy argument to from_dlpack is not yet implemented") diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index c93a08a..fc4e3cb 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -236,3 +236,10 @@ def from_dlpack_2023_12(api_version): pytest.raises(exception, lambda: from_dlpack(capsule, copy=False)) pytest.raises(exception, lambda: from_dlpack(capsule, copy=True)) pytest.raises(exception, lambda: from_dlpack(capsule, copy=None)) + + +def test_from_dlpack_default_device(): + x = asarray([1, 2, 3]) + y = from_dlpack(x) + z = from_dlpack(np.asarray([1, 2, 3])) + assert x.device == y.device == z.device == CPU_DEVICE From 4ac92551f4e56a8ee5ae50b15906db91ab453a99 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 16 Dec 2024 17:31:59 +0200 Subject: [PATCH 201/252] ENH: allow mean(complex) in 2024.12 --- array_api_strict/_statistical_functions.py | 11 ++++++++-- .../tests/test_statistical_functions.py | 20 +++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 6ea9746..f06785c 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -3,6 +3,7 @@ from ._dtypes import ( _real_floating_dtypes, _real_numeric_dtypes, + _floating_dtypes, _numeric_dtypes, ) from ._array_object import Array @@ -65,8 +66,14 @@ def mean( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> Array: - if x.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in mean") + + if get_array_api_strict_flags()['api_version'] > '2023.12': + allowed_dtypes = _floating_dtypes + else: + allowed_dtypes = _real_floating_dtypes + + if x.dtype not in allowed_dtypes: + raise TypeError("Only floating-point dtypes are allowed in mean") return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims), device=x.device) diff --git a/array_api_strict/tests/test_statistical_functions.py b/array_api_strict/tests/test_statistical_functions.py index 7f2a457..c97670d 100644 --- a/array_api_strict/tests/test_statistical_functions.py +++ b/array_api_strict/tests/test_statistical_functions.py @@ -1,3 +1,4 @@ +import cmath import pytest from .._flags import set_array_api_strict_flags @@ -37,3 +38,22 @@ def test_sum_prod_trace_2023_12(func_name): assert func(a_real).dtype == xp.float32 assert func(a_complex).dtype == xp.complex64 assert func(a_int).dtype == xp.int64 + + +# mean(complex-valued array) is allowed from 2024.12 onwards +def test_mean_complex(): + a = xp.asarray([1j, 2j, 3j]) + + set_array_api_strict_flags(api_version='2023.12') + with pytest.raises(TypeError): + xp.mean(a) + + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version='2024.12') + m = xp.mean(a) + assert cmath.isclose(complex(m), 2j) + + # mean of integer arrays is still not allowed + with pytest.raises(TypeError): + xp.mean(xp.arange(3)) + From 488c7336dba8195347007a852873b52e0b821d62 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 25 Dec 2024 18:32:52 +0200 Subject: [PATCH 202/252] MAINT: make __str__ less ambiguous For 0D arrays, __str__ used to look like a scalar: In [2]: x = xp.asarray(3) In [3]: print(x) 3 So make it clearly an array. --- array_api_strict/_array_object.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 0de6b8a..a917441 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -126,12 +126,6 @@ def __new__(cls, *args, **kwargs): # These functions are not required by the spec, but are implemented for # the sake of usability. - def __str__(self: Array, /) -> str: - """ - Performs the operation __str__. - """ - return self._array.__str__().replace("array", "Array") - def __repr__(self: Array, /) -> str: """ Performs the operation __repr__. @@ -149,6 +143,8 @@ def __repr__(self: Array, /) -> str: mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix) return prefix + mid + suffix + __str__ = __repr__ + # In the future, _allow_array will be set to False, which will disallow # __array__. This means calling `np.func()` on an array_api_strict array # will give an error. If we don't explicitly disallow it, NumPy defaults From c4419717554857ef1b220e4d3a4b5f94c19ac4e1 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 2 Jan 2025 14:59:27 +0000 Subject: [PATCH 203/252] isdtype() should raise if parameter is not a dtype --- .gitignore | 8 ++++---- array_api_strict/_data_type_functions.py | 3 +++ array_api_strict/tests/test_data_type_functions.py | 6 ++++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index dbce267..f69e911 100644 --- a/.gitignore +++ b/.gitignore @@ -128,12 +128,12 @@ ENV/ env.bak/ venv.bak/ -# Spyder project settings +# Project settings +.idea +.ropeproject .spyderproject .spyproject - -# Rope project settings -.ropeproject +.vscode # mkdocs documentation /site diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index acd8967..5af46d2 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -167,6 +167,9 @@ def isdtype( https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html for more details """ + if not isinstance(dtype, _DType): + raise TypeError(f"'dtype' must be a dtype, not a {type(dtype)!r}") + if isinstance(kind, tuple): # Disallow nested tuples if any(isinstance(k, tuple) for k in kind): diff --git a/array_api_strict/tests/test_data_type_functions.py b/array_api_strict/tests/test_data_type_functions.py index 40cab55..488eab7 100644 --- a/array_api_strict/tests/test_data_type_functions.py +++ b/array_api_strict/tests/test_data_type_functions.py @@ -31,15 +31,17 @@ def test_can_cast(from_, to, expected): def test_isdtype_strictness(): assert_raises(TypeError, lambda: isdtype(float64, 64)) assert_raises(ValueError, lambda: isdtype(float64, 'f8')) - assert_raises(TypeError, lambda: isdtype(float64, (('integral',),))) + assert_raises(TypeError, lambda: isdtype(float64, None)) + assert_raises(TypeError, lambda: isdtype(np.float64, float64)) + assert_raises(TypeError, lambda: isdtype(asarray(1.0), float64)) + with assert_raises(TypeError), warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") isdtype(float64, np.object_) assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) - assert_raises(TypeError, lambda: isdtype(float64, None)) with assert_raises(TypeError), warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") isdtype(float64, np.float64) From f00a882206c86369360f92ca3e550b9ca363d89c Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 16 Dec 2024 20:29:50 +0200 Subject: [PATCH 204/252] ENH: add count_nonzero --- array_api_strict/__init__.py | 4 ++-- array_api_strict/_searching_functions.py | 20 +++++++++++++++++++- array_api_strict/tests/test_flags.py | 1 + 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 98b0e95..da66c9e 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -293,9 +293,9 @@ __all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"] -from ._searching_functions import argmax, argmin, nonzero, searchsorted, where +from ._searching_functions import argmax, argmin, nonzero, count_nonzero, searchsorted, where -__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"] +__all__ += ["argmax", "argmin", "nonzero", "count_nonzero", "searchsorted", "where"] from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 5460b30..df91e44 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal, Optional, Tuple + from typing import Literal, Optional, Tuple, Union import numpy as np @@ -45,6 +45,24 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]: raise ValueError("nonzero is not allowed on 0-dimensional arrays") return tuple(Array._new(i, device=x.device) for i in np.nonzero(x._array)) + +@requires_api_version('2024.12') +def count_nonzero( + x: Array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, +) -> Array: + """ + Array API compatible wrapper for :py:func:`np.count_nonzero ` + + See its docstring for more information. + """ + arr = np.count_nonzero(x._array, axis=axis, keepdims=keepdims) + return Array._new(np.asarray(arr), device=x.device) + + @requires_api_version('2023.12') def searchsorted( x1: Array, diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index e0b004b..ebee415 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -307,6 +307,7 @@ def test_api_version_2023_12(func_name): 'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])), 'take_along_axis': lambda: xp.take_along_axis(xp.zeros((2, 3)), xp.zeros((1, 4), dtype=xp.int64)), + 'count_nonzero': lambda: xp.count_nonzero(xp.arange(3)), } @pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys()) From 912362c8b60f9ecb5b37752ec60ebd4382fa3f88 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 16 Dec 2024 21:48:32 +0200 Subject: [PATCH 205/252] ENH: add cumulative_prod (untested) --- array_api_strict/__init__.py | 4 +- array_api_strict/_dtypes.py | 28 +++++++++++++ array_api_strict/_statistical_functions.py | 49 ++++++++++++++++++++++ 3 files changed, 79 insertions(+), 2 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index da66c9e..27dec55 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -305,9 +305,9 @@ __all__ += ["argsort", "sort"] -from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var +from ._statistical_functions import cumulative_sum, cumulative_prod, max, mean, min, prod, std, sum, var -__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"] +__all__ += ["cumulative_sum", "cumulative_prod", "max", "mean", "min", "prod", "std", "sum", "var"] from ._utility_functions import all, any, diff diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index b51ed92..cf7581d 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ -127,6 +127,34 @@ def __hash__(self): } +def _bit_width(dtype): + """The bit width of an integer dtype""" + if dtype == int8 or dtype == uint8: + return 8 + elif dtype == int16 or dtype == uint16: + return 16 + elif dtype == int32 or dtype == uint32: + return 32 + elif dtype == int64 or dtype == uint64: + return 64 + else: + raise ValueError(f"_bit_width: {dtype = } not understood.") + + +def _get_unsigned_from_signed(dtype): + """Return an unsigned integral dtype to match the input dtype.""" + if dtype == int8: + return uint8 + elif dtype == int16: + return uint16 + elif dtype == int32: + return uint32 + elif dtype == int64: + return uint64 + else: + raise ValueError(f"_unsigned_from_signed: {dtype = } not understood.") + + # Note: the spec defines a restricted type promotion table compared to NumPy. # In particular, cross-kind promotions like integer + float or boolean + # integer are not allowed, even for functions that accept both kinds. diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index f06785c..53f1d2f 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -5,7 +5,10 @@ _real_numeric_dtypes, _floating_dtypes, _numeric_dtypes, + _integer_dtypes ) +from . import _dtypes +from . import _info from ._array_object import Array from ._dtypes import float32, complex64 from ._flags import requires_api_version, get_array_api_strict_flags @@ -47,6 +50,52 @@ def cumulative_sum( x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis) return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device) + +@requires_api_version('2024.12') +def cumulative_prod( + x: Array, + /, + *, + axis: Optional[int] = None, + dtype: Optional[Dtype] = None, + include_initial: bool = False, +) -> Array: + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in cumulative_prod") + if x.ndim == 0: + raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_prod") + + # TODO: either all this is done by numpy's cumprod (?), or cumulative_sum should follow the same dance. + if dtype is None: + if x.dtype in _integer_dtypes: + default_int = _info.__array_namespace_info__().default_dtypes()["integral"] + if _dtypes._bit_width(x.dtype) < _dtypes._bit_width(default_int): + if x.dtype in _dtypes._unsigned_integer_dtypes: + # find the unsigned integer of the same width as `default_int` + dtype = _dtypes._get_unsigned_from_signed(default_int) + else: + dtype = default_int + else: + dtype = x.dtype + else: + dtype = x.dtype + else: + if x.dtype != dtype: + x = xp.astype(dtype) + + if axis is None: + if x.ndim > 1: + raise ValueError("axis must be specified in cumulative_prod for more than one dimension") + axis = 0 + + # np.cumprod does not support include_initial + if include_initial: + if axis < 0: + axis += x.ndim + x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dtype), x], axis=axis) + return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype._np_dtype), device=x.device) + + def max( x: Array, /, From 3ff4ca641b3b4a7323b9627f8f82548928c0ce4d Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 16 Dec 2024 22:02:10 +0200 Subject: [PATCH 206/252] MAINT: simplify cumulative_prod --- array_api_strict/_dtypes.py | 28 ---------------------- array_api_strict/_statistical_functions.py | 28 ++++------------------ 2 files changed, 5 insertions(+), 51 deletions(-) diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index cf7581d..b51ed92 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ -127,34 +127,6 @@ def __hash__(self): } -def _bit_width(dtype): - """The bit width of an integer dtype""" - if dtype == int8 or dtype == uint8: - return 8 - elif dtype == int16 or dtype == uint16: - return 16 - elif dtype == int32 or dtype == uint32: - return 32 - elif dtype == int64 or dtype == uint64: - return 64 - else: - raise ValueError(f"_bit_width: {dtype = } not understood.") - - -def _get_unsigned_from_signed(dtype): - """Return an unsigned integral dtype to match the input dtype.""" - if dtype == int8: - return uint8 - elif dtype == int16: - return uint16 - elif dtype == int32: - return uint32 - elif dtype == int64: - return uint64 - else: - raise ValueError(f"_unsigned_from_signed: {dtype = } not understood.") - - # Note: the spec defines a restricted type promotion table compared to NumPy. # In particular, cross-kind promotions like integer + float or boolean + # integer are not allowed, even for functions that accept both kinds. diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 53f1d2f..461ee04 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -5,14 +5,11 @@ _real_numeric_dtypes, _floating_dtypes, _numeric_dtypes, - _integer_dtypes ) -from . import _dtypes -from . import _info from ._array_object import Array from ._dtypes import float32, complex64 from ._flags import requires_api_version, get_array_api_strict_flags -from ._creation_functions import zeros +from ._creation_functions import zeros, ones from ._manipulation_functions import concat from typing import TYPE_CHECKING @@ -65,23 +62,8 @@ def cumulative_prod( if x.ndim == 0: raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_prod") - # TODO: either all this is done by numpy's cumprod (?), or cumulative_sum should follow the same dance. - if dtype is None: - if x.dtype in _integer_dtypes: - default_int = _info.__array_namespace_info__().default_dtypes()["integral"] - if _dtypes._bit_width(x.dtype) < _dtypes._bit_width(default_int): - if x.dtype in _dtypes._unsigned_integer_dtypes: - # find the unsigned integer of the same width as `default_int` - dtype = _dtypes._get_unsigned_from_signed(default_int) - else: - dtype = default_int - else: - dtype = x.dtype - else: - dtype = x.dtype - else: - if x.dtype != dtype: - x = xp.astype(dtype) + if dtype is not None: + dtype = dtype._np_dtype if axis is None: if x.ndim > 1: @@ -92,8 +74,8 @@ def cumulative_prod( if include_initial: if axis < 0: axis += x.ndim - x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dtype), x], axis=axis) - return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype._np_dtype), device=x.device) + x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis) + return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype), device=x.device) def max( From 83ac04f1bff0750305543b96edbe05575c0a6e36 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 16 Dec 2024 22:04:00 +0200 Subject: [PATCH 207/252] BUG: fix dtype of include_initial in cumulative_sum In `concat([zeros(...), x])` zeros must have the same dtype as `x`. --- array_api_strict/_statistical_functions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 461ee04..e41e7ef 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -31,7 +31,6 @@ def cumulative_sum( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in cumulative_sum") - dt = x.dtype if dtype is None else dtype if dtype is not None: dtype = dtype._np_dtype @@ -44,7 +43,7 @@ def cumulative_sum( if include_initial: if axis < 0: axis += x.ndim - x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis) + x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis) return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device) From 61bf3c1e790685a3c0de2ef27e9e105623f351f1 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 16 Dec 2024 22:07:15 +0200 Subject: [PATCH 208/252] TST: add cumulative_prod to test_flags --- array_api_strict/tests/test_flags.py | 1 + 1 file changed, 1 insertion(+) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index ebee415..dcfc20d 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -308,6 +308,7 @@ def test_api_version_2023_12(func_name): 'take_along_axis': lambda: xp.take_along_axis(xp.zeros((2, 3)), xp.zeros((1, 4), dtype=xp.int64)), 'count_nonzero': lambda: xp.count_nonzero(xp.arange(3)), + 'cumulative_prod': lambda: xp.cumulative_prod(xp.arange(1, 5)), } @pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys()) From 820788b3f01d1b525e4a50fc33494ad8323f39bf Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 7 Jan 2025 14:14:58 +0200 Subject: [PATCH 209/252] ENH: allow python scalars as inputs to result_type --- array_api_strict/_data_type_functions.py | 27 ++++++++++++++++--- .../tests/test_data_type_functions.py | 23 ++++++++++++++-- 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 5af46d2..1643043 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -197,7 +197,7 @@ def isdtype( else: raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}") -def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype: +def result_type(*arrays_and_dtypes: Union[Array, Dtype, int, float, complex, bool]) -> Dtype: """ Array API compatible wrapper for :py:func:`np.result_type `. @@ -208,19 +208,40 @@ def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype: # too many extra type promotions like int64 + uint64 -> float64, and does # value-based casting on scalar arrays. A = [] + scalars = [] for a in arrays_and_dtypes: if isinstance(a, Array): a = a.dtype + elif isinstance(a, (bool, int, float, complex)): + scalars.append(a) elif isinstance(a, np.ndarray) or a not in _all_dtypes: raise TypeError("result_type() inputs must be array_api arrays or dtypes") A.append(a) + # remove python scalars + A = [a for a in A if not isinstance(a, (bool, int, float, complex))] + if len(A) == 0: raise ValueError("at least one array or dtype is required") elif len(A) == 1: - return A[0] + result = A[0] else: t = A[0] for t2 in A[1:]: t = _result_type(t, t2) - return t + result = t + + if len(scalars) == 0: + return result + + if get_array_api_strict_flags()['api_version'] <= '2023.12': + raise TypeError("result_type() inputs must be array_api arrays or dtypes") + + # promote python scalars given the result_type for all arrays/dtypes + from ._creation_functions import empty + arr = empty(1, dtype=result) + for s in scalars: + x = arr._promote_scalar(s) + result = _result_type(x.dtype, result) + + return result diff --git a/array_api_strict/tests/test_data_type_functions.py b/array_api_strict/tests/test_data_type_functions.py index 488eab7..863d3d4 100644 --- a/array_api_strict/tests/test_data_type_functions.py +++ b/array_api_strict/tests/test_data_type_functions.py @@ -6,9 +6,9 @@ import numpy as np from .._creation_functions import asarray -from .._data_type_functions import astype, can_cast, isdtype +from .._data_type_functions import astype, can_cast, isdtype, result_type from .._dtypes import ( - bool, int8, int16, uint8, float64, + bool, int8, int16, uint8, float64, int64 ) from .._flags import set_array_api_strict_flags @@ -70,3 +70,22 @@ def astype_device(api_version): else: pytest.raises(TypeError, lambda: astype(a, int8, device=None)) pytest.raises(TypeError, lambda: astype(a, int8, device=a.device)) + + +@pytest.mark.parametrize("api_version", ['2023.12', '2024.12']) +def test_result_type_py_scalars(api_version): + if api_version <= '2023.12': + set_array_api_strict_flags(api_version=api_version) + + with pytest.raises(TypeError): + result_type(int16, 3) + else: + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + + assert result_type(int8, 3) == int8 + assert result_type(uint8, 3) == uint8 + assert result_type(float64, 3) == float64 + + with pytest.raises(TypeError): + result_type(int64, True) From c2526d1357d31d62e924c4e6fe97a874bc739e4f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 7 Jan 2025 14:36:59 +0200 Subject: [PATCH 210/252] TST: fix tests for result_type(scalars) --- array_api_strict/tests/test_manipulation_functions.py | 2 +- array_api_strict/tests/test_validation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_strict/tests/test_manipulation_functions.py b/array_api_strict/tests/test_manipulation_functions.py index 9969651..bd247ee 100644 --- a/array_api_strict/tests/test_manipulation_functions.py +++ b/array_api_strict/tests/test_manipulation_functions.py @@ -11,7 +11,7 @@ def test_concat_errors(): - assert_raises(TypeError, lambda: concat((1, 1), axis=None)) + assert_raises((TypeError, ValueError), lambda: concat((1, 1), axis=None)) assert_raises(TypeError, lambda: concat([asarray([1], dtype=int8), asarray([1], dtype=float64)])) diff --git a/array_api_strict/tests/test_validation.py b/array_api_strict/tests/test_validation.py index 035b6f4..bd76ec6 100644 --- a/array_api_strict/tests/test_validation.py +++ b/array_api_strict/tests/test_validation.py @@ -18,7 +18,7 @@ def p(func: Callable, *args, **kwargs): [ p(xp.can_cast, 42, xp.int8), p(xp.can_cast, xp.int8, 42), - p(xp.result_type, 42), + p(xp.result_type, "42"), ], ) def test_raises_on_invalid_types(func, args, kwargs): From ea96e9b2dc186eacaec33df67c68d716f17bfb18 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 22 Nov 2024 11:07:40 +0200 Subject: [PATCH 211/252] ENH: allow python scalars in binary elementwise functions Allow func(array, scalar) and func(scalar, array), raise on func(scalar, scalar) if API_VERSION>=2024.12 cross-ref https://github.com/data-apis/array-api/issues/807 To make sure it is all uniform, 1. Generate all binary "ufuncs" in a uniform way, with a decorator 2. Make binary "ufuncs" follow the same logic of the binary operators 3. Reuse the test loop of Array.__binop__ for binary "ufuncs" 4. (minor) in tests, reuse canonical names for dtype categories ("integer or boolean" vs "integer_or_boolean") --- array_api_strict/_array_object.py | 2 + array_api_strict/_elementwise_functions.py | 584 ++++-------------- array_api_strict/_helpers.py | 37 ++ array_api_strict/tests/test_array_object.py | 101 +-- .../tests/test_elementwise_functions.py | 54 +- 5 files changed, 266 insertions(+), 512 deletions(-) create mode 100644 array_api_strict/_helpers.py diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index a917441..47153e5 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -230,6 +230,8 @@ def _check_device(self, other): elif isinstance(other, Array): if self.device != other.device: raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") + else: + raise TypeError(f"Cannot combine an Array with {type(other)}.") # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar): diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 7c64f67..3c4b3d8 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -10,17 +10,133 @@ _real_numeric_dtypes, _numeric_dtypes, _result_type, + _dtype_categories as _dtype_dtype_categories, ) from ._array_object import Array from ._flags import requires_api_version from ._creation_functions import asarray from ._data_type_functions import broadcast_to, iinfo +from ._helpers import _maybe_normalize_py_scalars from typing import Optional, Union import numpy as np +def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func): + """Base implementation of a binary function, `func_name`, defined for + dtypes from `dtype_category` + """ + x1, x2 = _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name) + + if x1.device != x2.device: + raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + # Call result type here just to raise on disallowed type combinations + _result_type(x1.dtype, x2.dtype) + x1, x2 = Array._normalize_two_args(x1, x2) + return Array._new(np_func(x1._array, x2._array), device=x1.device) + + +_binary_docstring_template=\ +""" +Array API compatible wrapper for :py:func:`np.%s `. + +See its docstring for more information. +""" + + +def create_binary_func(func_name, dtype_category, np_func): + def inner(x1: Array, x2: Array, /) -> Array: + return _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func) + return inner + + +# func_name: dtype_category (must match that from _dtypes.py) +_binary_funcs = { + "add": "numeric", + "atan2": "real floating-point", + "bitwise_and": "integer or boolean", + "bitwise_or": "integer or boolean", + "bitwise_xor": "integer or boolean", + "_bitwise_left_shift": "integer", # leading underscore deliberate + "_bitwise_right_shift": "integer", + # XXX: copysign: real fp or numeric? + "copysign": "real floating-point", + "divide": "floating-point", + "equal": "all", + "greater": "real numeric", + "greater_equal": "real numeric", + "less": "real numeric", + "less_equal": "real numeric", + "not_equal": "all", + "floor_divide": "real numeric", + "hypot": "real floating-point", + "logaddexp": "real floating-point", + "logical_and": "boolean", + "logical_or": "boolean", + "logical_xor": "boolean", + "maximum": "real numeric", + "minimum": "real numeric", + "multiply": "numeric", + "nextafter": "real floating-point", + "pow": "numeric", + "remainder": "real numeric", + "subtract": "numeric", +} + + +# map array-api-name : numpy-name +_numpy_renames = { + "atan2": "arctan2", + "_bitwise_left_shift": "left_shift", + "_bitwise_right_shift": "right_shift", + "pow": "power" +} + + +# create and attach functions to the module +for func_name, dtype_category in _binary_funcs.items(): + # sanity check + assert dtype_category in _dtype_dtype_categories + + numpy_name = _numpy_renames.get(func_name, func_name) + np_func = getattr(np, numpy_name) + + func = create_binary_func(func_name, dtype_category, np_func) + func.__name__ = func_name + + func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name) + + vars()[func_name] = func + + +copysign = requires_api_version('2023.12')(copysign) # noqa: F821 +hypot = requires_api_version('2023.12')(hypot) # noqa: F821 +maximum = requires_api_version('2023.12')(maximum) # noqa: F821 +minimum = requires_api_version('2023.12')(minimum) # noqa: F821 +nextafter = requires_api_version('2024.12')(nextafter) # noqa: F821 + + +def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: + is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0 + if is_negative: + raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") + return _bitwise_left_shift(x1, x2) # noqa: F821 +bitwise_left_shift.__doc__ = _bitwise_left_shift.__doc__ # noqa: F821 + + +def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: + is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0 + if is_negative: + raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") + return _bitwise_right_shift(x1, x2) # noqa: F821 +bitwise_right_shift.__doc__ = _bitwise_right_shift.__doc__ # noqa: F821 + + +# clean up to not pollute the namespace +del func, create_binary_func + + def abs(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.abs `. @@ -56,23 +172,6 @@ def acosh(x: Array, /) -> Array: return Array._new(np.arccosh(x._array), device=x.device) -def add(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.add `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in add") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.add(x1._array, x2._array), device=x1.device) - - # Note: the function name is different here def asin(x: Array, /) -> Array: """ @@ -109,23 +208,6 @@ def atan(x: Array, /) -> Array: return Array._new(np.arctan(x._array), device=x.device) -# Note: the function name is different here -def atan2(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arctan2 `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in atan2") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.arctan2(x1._array, x2._array), device=x1.device) - - # Note: the function name is different here def atanh(x: Array, /) -> Array: """ @@ -138,47 +220,6 @@ def atanh(x: Array, /) -> Array: return Array._new(np.arctanh(x._array), device=x.device) -def bitwise_and(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.bitwise_and `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if ( - x1.dtype not in _integer_or_boolean_dtypes - or x2.dtype not in _integer_or_boolean_dtypes - ): - raise TypeError("Only integer or boolean dtypes are allowed in bitwise_and") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_and(x1._array, x2._array), device=x1.device) - - -# Note: the function name is different here -def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.left_shift `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: - raise TypeError("Only integer dtypes are allowed in bitwise_left_shift") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - # Note: bitwise_left_shift is only defined for x2 nonnegative. - if np.any(x2._array < 0): - raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") - return Array._new(np.left_shift(x1._array, x2._array), device=x1.device) - - # Note: the function name is different here def bitwise_invert(x: Array, /) -> Array: """ @@ -191,67 +232,6 @@ def bitwise_invert(x: Array, /) -> Array: return Array._new(np.invert(x._array), device=x.device) -def bitwise_or(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.bitwise_or `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if ( - x1.dtype not in _integer_or_boolean_dtypes - or x2.dtype not in _integer_or_boolean_dtypes - ): - raise TypeError("Only integer or boolean dtypes are allowed in bitwise_or") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_or(x1._array, x2._array), device=x1.device) - - -# Note: the function name is different here -def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.right_shift `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes: - raise TypeError("Only integer dtypes are allowed in bitwise_right_shift") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - # Note: bitwise_right_shift is only defined for x2 nonnegative. - if np.any(x2._array < 0): - raise ValueError("bitwise_right_shift(x1, x2) is only defined for x2 >= 0") - return Array._new(np.right_shift(x1._array, x2._array), device=x1.device) - - -def bitwise_xor(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.bitwise_xor `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if ( - x1.dtype not in _integer_or_boolean_dtypes - or x2.dtype not in _integer_or_boolean_dtypes - ): - raise TypeError("Only integer or boolean dtypes are allowed in bitwise_xor") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.bitwise_xor(x1._array, x2._array), device=x1.device) - - def ceil(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.ceil `. @@ -372,6 +352,7 @@ def _isscalar(a): out[ib] = b[ib] return Array._new(out, device=device) + def conj(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.conj `. @@ -382,22 +363,6 @@ def conj(x: Array, /) -> Array: raise TypeError("Only complex floating-point dtypes are allowed in conj") return Array._new(np.conj(x._array), device=x.device) -@requires_api_version('2023.12') -def copysign(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.copysign `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real numeric dtypes are allowed in copysign") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.copysign(x1._array, x2._array), device=x1.device) def cos(x: Array, /) -> Array: """ @@ -421,36 +386,6 @@ def cosh(x: Array, /) -> Array: return Array._new(np.cosh(x._array), device=x.device) -def divide(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.divide `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in divide") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.divide(x1._array, x2._array), device=x1.device) - - -def equal(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.equal `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.equal(x1._array, x2._array), device=x1.device) - - def exp(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.exp `. @@ -487,69 +422,6 @@ def floor(x: Array, /) -> Array: return Array._new(np.floor(x._array), device=x.device) -def floor_divide(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.floor_divide `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in floor_divide") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.floor_divide(x1._array, x2._array), device=x1.device) - - -def greater(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.greater `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in greater") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.greater(x1._array, x2._array), device=x1.device) - - -def greater_equal(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.greater_equal `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in greater_equal") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.greater_equal(x1._array, x2._array), device=x1.device) - -@requires_api_version('2023.12') -def hypot(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.hypot `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in hypot") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.hypot(x1._array, x2._array), device=x1.device) - def imag(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.imag `. @@ -594,38 +466,6 @@ def isnan(x: Array, /) -> Array: return Array._new(np.isnan(x._array), device=x.device) -def less(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.less `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in less") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.less(x1._array, x2._array), device=x1.device) - - -def less_equal(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.less_equal `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in less_equal") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.less_equal(x1._array, x2._array), device=x1.device) - - def log(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.log `. @@ -670,38 +510,6 @@ def log10(x: Array, /) -> Array: return Array._new(np.log10(x._array), device=x.device) -def logaddexp(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logaddexp `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in logaddexp") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logaddexp(x1._array, x2._array), device=x1.device) - - -def logical_and(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logical_and `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError("Only boolean dtypes are allowed in logical_and") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_and(x1._array, x2._array), device=x1.device) - - def logical_not(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.logical_not `. @@ -713,87 +521,6 @@ def logical_not(x: Array, /) -> Array: return Array._new(np.logical_not(x._array), device=x.device) -def logical_or(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logical_or `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError("Only boolean dtypes are allowed in logical_or") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_or(x1._array, x2._array), device=x1.device) - - -def logical_xor(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logical_xor `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes: - raise TypeError("Only boolean dtypes are allowed in logical_xor") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.logical_xor(x1._array, x2._array), device=x1.device) - -@requires_api_version('2023.12') -def maximum(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.maximum `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in maximum") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - # TODO: maximum(-0., 0.) is unspecified. Should we issue a warning/error - # in that case? - return Array._new(np.maximum(x1._array, x2._array), device=x1.device) - -@requires_api_version('2023.12') -def minimum(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.minimum `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in minimum") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.minimum(x1._array, x2._array), device=x1.device) - -def multiply(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.multiply `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in multiply") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.multiply(x1._array, x2._array), device=x1.device) - - def negative(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.negative `. @@ -805,34 +532,6 @@ def negative(x: Array, /) -> Array: return Array._new(np.negative(x._array), device=x.device) -@requires_api_version('2024.12') -def nextafter(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.nextafter `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in nextafter") - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.nextafter(x1._array, x2._array), device=x1.device) - -def not_equal(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.not_equal `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.not_equal(x1._array, x2._array), device=x1.device) - - def positive(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.positive `. @@ -844,23 +543,6 @@ def positive(x: Array, /) -> Array: return Array._new(np.positive(x._array), device=x.device) -# Note: the function name is different here -def pow(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.power `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in pow") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.power(x1._array, x2._array), device=x1.device) - - def real(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.real `. @@ -883,22 +565,6 @@ def reciprocal(x: Array, /) -> Array: raise TypeError("Only floating-point dtypes are allowed in reciprocal") return Array._new(np.reciprocal(x._array), device=x.device) -def remainder(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.remainder `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in remainder") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.remainder(x1._array, x2._array), device=x1.device) - - def round(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.round `. @@ -979,22 +645,6 @@ def sqrt(x: Array, /) -> Array: return Array._new(np.sqrt(x._array), device=x.device) -def subtract(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.subtract `. - - See its docstring for more information. - """ - if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in subtract") - # Call result type here just to raise on disallowed type combinations - _result_type(x1.dtype, x2.dtype) - x1, x2 = Array._normalize_two_args(x1, x2) - return Array._new(np.subtract(x1._array, x2._array), device=x1.device) - - def tan(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.tan `. diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py new file mode 100644 index 0000000..2258d29 --- /dev/null +++ b/array_api_strict/_helpers.py @@ -0,0 +1,37 @@ +"""Private helper routines. +""" + +from ._flags import get_array_api_strict_flags +from ._dtypes import _dtype_categories + +_py_scalars = (bool, int, float, complex) + + +def _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name): + + flags = get_array_api_strict_flags() + if flags["api_version"] < "2024.12": + # scalars will fail at the call site + return x1, x2 + + _allowed_dtypes = _dtype_categories[dtype_category] + + if isinstance(x1, _py_scalars): + if isinstance(x2, _py_scalars): + raise TypeError(f"Two scalars not allowed, got {type(x1) = } and {type(x2) =}") + # x2 must be an array + if x2.dtype not in _allowed_dtypes: + raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x2.dtype}.") + x1 = x2._promote_scalar(x1) + + elif isinstance(x2, _py_scalars): + # x1 must be an array + if x1.dtype not in _allowed_dtypes: + raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x1.dtype}.") + x2 = x1._promote_scalar(x2) + else: + if x1.dtype not in _allowed_dtypes or x2.dtype not in _allowed_dtypes: + raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. " + f"Got {x1.dtype} and {x2.dtype}.") + return x1, x2 + diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 8f185f0..4535d99 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -96,12 +96,60 @@ def test_promoted_scalar_inherits_device(): assert y.device == device1 + +BIG_INT = int(1e30) + +def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT): + # Test array op scalar. From the spec, the following combinations + # are supported: + + # - Python bool for a bool array dtype, + # - a Python int within the bounds of the given dtype for integer array dtypes, + # - a Python int or float for real floating-point array dtypes + # - a Python int, float, or complex for complex floating-point array dtypes + + if ((dtypes == "all" + or dtypes == "numeric" and a.dtype in _numeric_dtypes + or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes + or dtypes == "integer" and a.dtype in _integer_dtypes + or dtypes == "integer or boolean" and a.dtype in _integer_or_boolean_dtypes + or dtypes == "boolean" and a.dtype in _boolean_dtypes + or dtypes == "floating-point" and a.dtype in _floating_dtypes + or dtypes == "real floating-point" and a.dtype in _real_floating_dtypes + ) + # bool is a subtype of int, which is why we avoid + # isinstance here. + and (a.dtype in _boolean_dtypes and type(s) == bool + or a.dtype in _integer_dtypes and type(s) == int + or a.dtype in _real_floating_dtypes and type(s) in [float, int] + or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int] + )): + if a.dtype in _integer_dtypes and s == BIG_INT: + with assert_raises(OverflowError): + func(s) + return False + + else: + # Only test for no error + with suppress_warnings() as sup: + # ignore warnings from pow(BIG_INT) + sup.filter(RuntimeWarning, + "invalid value encountered in power") + func(s) + return True + + else: + with assert_raises(TypeError): + func(s) + return False + + def test_operators(): # For every operator, we test that it works for the required type # combinations and raises TypeError otherwise binary_op_dtypes = { "__add__": "numeric", - "__and__": "integer_or_boolean", + "__and__": "integer or boolean", "__eq__": "all", "__floordiv__": "real numeric", "__ge__": "real numeric", @@ -112,12 +160,12 @@ def test_operators(): "__mod__": "real numeric", "__mul__": "numeric", "__ne__": "all", - "__or__": "integer_or_boolean", + "__or__": "integer or boolean", "__pow__": "numeric", "__rshift__": "integer", "__sub__": "numeric", - "__truediv__": "floating", - "__xor__": "integer_or_boolean", + "__truediv__": "floating-point", + "__xor__": "integer or boolean", } # Recompute each time because of in-place ops def _array_vals(): @@ -128,8 +176,6 @@ def _array_vals(): for d in _floating_dtypes: yield asarray(1.0, dtype=d) - - BIG_INT = int(1e30) for op, dtypes in binary_op_dtypes.items(): ops = [op] if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]: @@ -139,40 +185,7 @@ def _array_vals(): for s in [1, 1.0, 1j, BIG_INT, False]: for _op in ops: for a in _array_vals(): - # Test array op scalar. From the spec, the following combinations - # are supported: - - # - Python bool for a bool array dtype, - # - a Python int within the bounds of the given dtype for integer array dtypes, - # - a Python int or float for real floating-point array dtypes - # - a Python int, float, or complex for complex floating-point array dtypes - - if ((dtypes == "all" - or dtypes == "numeric" and a.dtype in _numeric_dtypes - or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes - or dtypes == "integer" and a.dtype in _integer_dtypes - or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes - or dtypes == "boolean" and a.dtype in _boolean_dtypes - or dtypes == "floating" and a.dtype in _floating_dtypes - ) - # bool is a subtype of int, which is why we avoid - # isinstance here. - and (a.dtype in _boolean_dtypes and type(s) == bool - or a.dtype in _integer_dtypes and type(s) == int - or a.dtype in _real_floating_dtypes and type(s) in [float, int] - or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int] - )): - if a.dtype in _integer_dtypes and s == BIG_INT: - assert_raises(OverflowError, lambda: getattr(a, _op)(s)) - else: - # Only test for no error - with suppress_warnings() as sup: - # ignore warnings from pow(BIG_INT) - sup.filter(RuntimeWarning, - "invalid value encountered in power") - getattr(a, _op)(s) - else: - assert_raises(TypeError, lambda: getattr(a, _op)(s)) + _check_op_array_scalar(dtypes, a, s, getattr(a, _op), _op) # Test array op array. for _op in ops: @@ -203,10 +216,10 @@ def _array_vals(): or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes) or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes - or dtypes == "integer_or_boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes + or dtypes == "integer or boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes or x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes) or dtypes == "boolean" and x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes - or dtypes == "floating" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes + or dtypes == "floating-point" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes ): getattr(x, _op)(y) else: @@ -214,7 +227,7 @@ def _array_vals(): unary_op_dtypes = { "__abs__": "numeric", - "__invert__": "integer_or_boolean", + "__invert__": "integer or boolean", "__neg__": "numeric", "__pos__": "numeric", } @@ -223,7 +236,7 @@ def _array_vals(): if ( dtypes == "numeric" and a.dtype in _numeric_dtypes - or dtypes == "integer_or_boolean" + or dtypes == "integer or boolean" and a.dtype in _integer_or_boolean_dtypes ): # Only test for no error diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 4e1b9cc..0b90f0b 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,6 +1,7 @@ from inspect import signature, getmodule -from numpy.testing import assert_raises +from pytest import raises as assert_raises +from numpy.testing import suppress_warnings import pytest @@ -19,6 +20,8 @@ ) from .._flags import set_array_api_strict_flags +from .test_array_object import _check_op_array_scalar, BIG_INT + import array_api_strict @@ -120,6 +123,7 @@ def test_missing_functions(): # Ensure the above dictionary is complete. import array_api_strict._elementwise_functions as mod mod_funcs = [n for n in dir(mod) if getmodule(getattr(mod, n)) is mod] + mod_funcs = [n for n in mod_funcs if not n.startswith("_")] assert set(mod_funcs) == set(elementwise_function_input_types) @@ -202,3 +206,51 @@ def test_bitwise_shift_error(): assert_raises( ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1])) ) + + + +def test_scalars(): + # mirror test_array_object.py::test_operators() + # + # Also check that binary functions accept (array, scalar) and (scalar, array) + # arguments, and reject (scalar, scalar) arguments. + + # Use the latest version of the standard so that scalars are actually allowed + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version="2024.12") + + def _array_vals(): + for d in _integer_dtypes: + yield asarray(1, dtype=d) + for d in _boolean_dtypes: + yield asarray(False, dtype=d) + for d in _floating_dtypes: + yield asarray(1.0, dtype=d) + + + for func_name, dtypes in elementwise_function_input_types.items(): + func = getattr(_elementwise_functions, func_name) + if nargs(func) != 2: + continue + + for s in [1, 1.0, 1j, BIG_INT, False]: + for a in _array_vals(): + for func1 in [lambda s: func(a, s), lambda s: func(s, a)]: + allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name) + + # only check `func(array, scalar) == `func(array, array)` if + # the former is legal under the promotion rules + if allowed: + conv_scalar = asarray(s, dtype=a.dtype) + + with suppress_warnings() as sup: + # ignore warnings from pow(BIG_INT) + sup.filter(RuntimeWarning, + "invalid value encountered in power") + assert func(s, a) == func(conv_scalar, a) + assert func(a, s) == func(a, conv_scalar) + + with assert_raises(TypeError): + func(s, s) + + From 33055ce2b0d5560a6275f97214bef030ec00b260 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Dec 2024 16:13:59 +0200 Subject: [PATCH 212/252] add type annotations to binary functions --- array_api_strict/_elementwise_functions.py | 33 +++++++++++++++++----- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 3c4b3d8..54691d6 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -10,7 +10,7 @@ _real_numeric_dtypes, _numeric_dtypes, _result_type, - _dtype_categories as _dtype_dtype_categories, + _dtype_categories, ) from ._array_object import Array from ._flags import requires_api_version @@ -46,11 +46,26 @@ def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func): def create_binary_func(func_name, dtype_category, np_func): - def inner(x1: Array, x2: Array, /) -> Array: + def inner(x1, x2, /) -> Array: return _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func) return inner +# static type annotation for ArrayOrPythonScalar arguments given a category +# NB: keep the keys in sync with the _dtype_categories dict +_annotations = { + "all": "bool | int | float | complex | Array", + "real numeric": "int | float | Array", + "numeric": "int | float | complex | Array", + "integer": "int | Array", + "integer or boolean": "int | bool | Array", + "boolean": "bool | Array", + "real floating-point": "float | Array", + "complex floating-point": "complex | Array", + "floating-point": "float | complex | Array", +} + + # func_name: dtype_category (must match that from _dtypes.py) _binary_funcs = { "add": "numeric", @@ -97,7 +112,7 @@ def inner(x1: Array, x2: Array, /) -> Array: # create and attach functions to the module for func_name, dtype_category in _binary_funcs.items(): # sanity check - assert dtype_category in _dtype_dtype_categories + assert dtype_category in _dtype_categories numpy_name = _numpy_renames.get(func_name, func_name) np_func = getattr(np, numpy_name) @@ -106,6 +121,8 @@ def inner(x1: Array, x2: Array, /) -> Array: func.__name__ = func_name func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name) + func.__annotations__['x1'] = _annotations[dtype_category] + func.__annotations__['x2'] = _annotations[dtype_category] vars()[func_name] = func @@ -117,20 +134,22 @@ def inner(x1: Array, x2: Array, /) -> Array: nextafter = requires_api_version('2024.12')(nextafter) # noqa: F821 -def bitwise_left_shift(x1: Array, x2: Array, /) -> Array: +def bitwise_left_shift(x1: int | Array, x2: int | Array, /) -> Array: is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0 if is_negative: raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") return _bitwise_left_shift(x1, x2) # noqa: F821 -bitwise_left_shift.__doc__ = _bitwise_left_shift.__doc__ # noqa: F821 +if _bitwise_left_shift.__doc__: # noqa: F821 + bitwise_left_shift.__doc__ = _bitwise_left_shift.__doc__ # noqa: F821 -def bitwise_right_shift(x1: Array, x2: Array, /) -> Array: +def bitwise_right_shift(x1: int | Array, x2: int | Array, /) -> Array: is_negative = np.any(x2._array < 0) if isinstance(x2, Array) else x2 < 0 if is_negative: raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0") return _bitwise_right_shift(x1, x2) # noqa: F821 -bitwise_right_shift.__doc__ = _bitwise_right_shift.__doc__ # noqa: F821 +if _bitwise_right_shift.__doc__: # noqa: F821 + bitwise_right_shift.__doc__ = _bitwise_right_shift.__doc__ # noqa: F821 # clean up to not pollute the namespace From 035cf2da0cb0c492897d6c2f68ffdf282ec55931 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 24 Jan 2025 12:23:44 +0100 Subject: [PATCH 213/252] MAINT: undo the array_object.py change --- array_api_strict/_array_object.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 47153e5..a917441 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -230,8 +230,6 @@ def _check_device(self, other): elif isinstance(other, Array): if self.device != other.device: raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") - else: - raise TypeError(f"Cannot combine an Array with {type(other)}.") # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar): From b9053c4c50f9a76d16281e73330d8a4ffd7f9053 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:56:57 +0000 Subject: [PATCH 214/252] Bump the actions group with 2 updates Bumps the actions group with 2 updates: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact) and [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish). Updates `dawidd6/action-download-artifact` from 7 to 8 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v7...v8) Updates `pypa/gh-action-pypi-publish` from 1.12.3 to 1.12.4 - [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.12.3...v1.12.4) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions - dependency-name: pypa/gh-action-pypi-publish dependency-type: direct:production update-type: version-update:semver-patch dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- .github/workflows/publish-package.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 4106b88..3700c17 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download Artifact - uses: dawidd6/action-download-artifact@v7 + uses: dawidd6/action-download-artifact@v8 with: workflow: docs-build.yml name: docs-build diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 66c5cc6..05b2c49 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -97,14 +97,14 @@ jobs: # if: >- # (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) # || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') - # uses: pypa/gh-action-pypi-publish@v1.12.3 + # uses: pypa/gh-action-pypi-publish@v1.12.4 # with: # repository-url: https://test.pypi.org/legacy/ # print-hash: true - name: Publish distribution 📦 to PyPI if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - uses: pypa/gh-action-pypi-publish@v1.12.3 + uses: pypa/gh-action-pypi-publish@v1.12.4 with: print-hash: true From 06a5351886689191ff4ae56dc3fcbc9ef154834f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 29 Jan 2025 00:12:51 +0100 Subject: [PATCH 215/252] ENH: allow 1j * float_array --- array_api_strict/_array_object.py | 12 +++++++++--- array_api_strict/_dtypes.py | 1 + 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index a917441..afee030 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -26,10 +26,12 @@ _integer_dtypes, _integer_or_boolean_dtypes, _floating_dtypes, + _real_floating_dtypes, _complex_floating_dtypes, _numeric_dtypes, _result_type, _dtype_categories, + _real_to_complex_map, ) from ._flags import get_array_api_strict_flags, set_array_api_strict_flags @@ -243,6 +245,7 @@ def _promote_scalar(self, scalar): """ from ._data_type_functions import iinfo + target_dtype = self.dtype # Note: Only Python scalar types that match the array dtype are # allowed. if isinstance(scalar, bool): @@ -268,10 +271,13 @@ def _promote_scalar(self, scalar): "Python float scalars can only be promoted with floating-point arrays." ) elif isinstance(scalar, complex): - if self.dtype not in _complex_floating_dtypes: + if self.dtype not in _floating_dtypes: raise TypeError( - "Python complex scalars can only be promoted with complex floating-point arrays." + "Python complex scalars can only be promoted with floating-point arrays." ) + # 1j * array(floating) is allowed + if self.dtype in _real_floating_dtypes: + target_dtype = _real_to_complex_map[self.dtype] else: raise TypeError("'scalar' must be a Python scalar") @@ -282,7 +288,7 @@ def _promote_scalar(self, scalar): # behavior for integers within the bounds of the integer dtype. # Outside of those bounds we use the default NumPy behavior (either # cast or raise OverflowError). - return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=self.device) + return Array._new(np.array(scalar, dtype=target_dtype._np_dtype), device=self.device) @staticmethod def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index b51ed92..66304dd 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ -126,6 +126,7 @@ def __hash__(self): "floating-point": _floating_dtypes, } +_real_to_complex_map = {float32: complex64, float64: complex128} # Note: the spec defines a restricted type promotion table compared to NumPy. # In particular, cross-kind promotions like integer + float or boolean + From 2d9254db54dbacc3761de08c22ecbfeea7a3d9c8 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 29 Jan 2025 01:00:35 +0100 Subject: [PATCH 216/252] TST: add workarounds for funcs which do not support 1jfloat_array --- array_api_strict/tests/test_array_object.py | 10 +++++++++- .../tests/test_elementwise_functions.py | 14 ++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 4535d99..edfa073 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -108,6 +108,14 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT): # - a Python int or float for real floating-point array dtypes # - a Python int, float, or complex for complex floating-point array dtypes + # an exception: complex scalar floating array + scalar_types_for_float = [float, int] + if not (func_name.startswith("__i") + or (func_name in ["__floordiv__", "__rfloordiv__", "__mod__", "__rmod__"] + and type(s) == complex) + ): + scalar_types_for_float += [complex] + if ((dtypes == "all" or dtypes == "numeric" and a.dtype in _numeric_dtypes or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes @@ -121,7 +129,7 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT): # isinstance here. and (a.dtype in _boolean_dtypes and type(s) == bool or a.dtype in _integer_dtypes and type(s) == int - or a.dtype in _real_floating_dtypes and type(s) in [float, int] + or a.dtype in _real_floating_dtypes and type(s) in scalar_types_for_float or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int] )): if a.dtype in _integer_dtypes and s == BIG_INT: diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 0b90f0b..cc3a2cd 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -233,15 +233,25 @@ def _array_vals(): if nargs(func) != 2: continue + nocomplex = [ + 'atan2', 'copysign', 'floor_divide', 'hypot', 'logaddexp', 'nextafter', + 'remainder', + 'greater', 'less', 'greater_equal', 'less_equal', 'maximum', 'minimum', + ] + for s in [1, 1.0, 1j, BIG_INT, False]: for a in _array_vals(): for func1 in [lambda s: func(a, s), lambda s: func(s, a)]: - allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name) + + if func_name in nocomplex and type(s) == complex: + allowed = False + else: + allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name) # only check `func(array, scalar) == `func(array, array)` if # the former is legal under the promotion rules if allowed: - conv_scalar = asarray(s, dtype=a.dtype) + conv_scalar = a._promote_scalar(s) with suppress_warnings() as sup: # ignore warnings from pow(BIG_INT) From fe9760e78c8cee4bb0bd981f034bed8c71118b46 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 1 Feb 2025 15:20:33 +0100 Subject: [PATCH 217/252] ENH: add dtype kwarg to fft.{fftfreq, rfftfreq} --- array_api_strict/_fft.py | 36 +++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/array_api_strict/_fft.py b/array_api_strict/_fft.py index 4b0ceb6..c888826 100644 --- a/array_api_strict/_fft.py +++ b/array_api_strict/_fft.py @@ -4,7 +4,7 @@ if TYPE_CHECKING: from typing import Union, Optional, Literal - from ._typing import Device + from ._typing import Device, Dtype as DType from collections.abc import Sequence from ._dtypes import ( @@ -251,7 +251,14 @@ def ihfft( return res @requires_extension('fft') -def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array: +def fftfreq( + n: int, + /, + *, + d: float = 1.0, + dtype: Optional[DType] = None, + device: Optional[Device] = None +) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.fftfreq `. @@ -259,10 +266,23 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Ar """ if device is not None and device not in ALL_DEVICES: raise ValueError(f"Unsupported device {device!r}") - return Array._new(np.fft.fftfreq(n, d=d), device=device) + if dtype and not dtype in _real_floating_dtypes: + raise ValueError(f"`dtype` must be a real floating-point type. Got {dtype=}.") + + np_result = np.fft.fftfreq(n, d=d) + if dtype: + np_result = np_result.astype(dtype._np_dtype) + return Array._new(np_result, device=device) @requires_extension('fft') -def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array: +def rfftfreq( + n: int, + /, + *, + d: float = 1.0, + dtype: Optional[DType] = None, + device: Optional[Device] = None +) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.rfftfreq `. @@ -270,7 +290,13 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> A """ if device is not None and device not in ALL_DEVICES: raise ValueError(f"Unsupported device {device!r}") - return Array._new(np.fft.rfftfreq(n, d=d), device=device) + if dtype and not dtype in _real_floating_dtypes: + raise ValueError(f"`dtype` must be a real floating-point type. Got {dtype=}.") + + np_result = np.fft.rfftfreq(n, d=d) + if dtype: + np_result = np_result.astype(dtype._np_dtype) + return Array._new(np_result, device=device) @requires_extension('fft') def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: From 590a2deb7899cb6f7a06f3ac68238ef28e69833d Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 3 Feb 2025 13:57:26 +0100 Subject: [PATCH 218/252] Add support for scalar arguments to xp.where (#78) Reviewed at https://github.com/data-apis/array-api-strict/pull/78 --- array_api_strict/_searching_functions.py | 24 +++++++++++++-- .../tests/test_searching_functions.py | 30 +++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) create mode 100644 array_api_strict/tests/test_searching_functions.py diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index df91e44..ad32aaa 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -2,7 +2,7 @@ from ._array_object import Array from ._dtypes import _result_type, _real_numeric_dtypes, bool as _bool -from ._flags import requires_data_dependent_shapes, requires_api_version +from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -90,12 +90,30 @@ def searchsorted( # x1 must be 1-D, but NumPy already requires this. return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device) -def where(condition: Array, x1: Array, x2: Array, /) -> Array: +def where( + condition: Array, + x1: bool | int | float | complex | Array, + x2: bool | int | float | complex | Array, / +) -> Array: """ Array API compatible wrapper for :py:func:`np.where `. See its docstring for more information. """ + if get_array_api_strict_flags()['api_version'] > '2023.12': + num_scalars = 0 + + if isinstance(x1, (bool, float, complex, int)): + x1 = Array._new(np.asarray(x1), device=condition.device) + num_scalars += 1 + + if isinstance(x2, (bool, float, complex, int)): + x2 = Array._new(np.asarray(x2), device=condition.device) + num_scalars += 1 + + if num_scalars == 2: + raise ValueError("One of x1, x2 arguments must be an array.") + # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) @@ -103,7 +121,7 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array: raise TypeError("`condition` must be have a boolean data type") if len({a.device for a in (condition, x1, x2)}) > 1: - raise ValueError("where inputs must all be on the same device") + raise ValueError("Inputs to `where` must all use the same device") x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np.where(condition._array, x1._array, x2._array), device=x1.device) diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py new file mode 100644 index 0000000..dfb3fe7 --- /dev/null +++ b/array_api_strict/tests/test_searching_functions.py @@ -0,0 +1,30 @@ +import pytest + +import array_api_strict as xp + +from array_api_strict import ArrayAPIStrictFlags +from array_api_strict._flags import draft_version + + +def test_where_with_scalars(): + x = xp.asarray([1, 2, 3, 1]) + + # Versions up to and including 2023.12 don't support scalar arguments + with pytest.raises(AttributeError, match="object has no attribute 'dtype'"): + xp.where(x == 1, 42, 44) + + # Versions after 2023.12 support scalar arguments + with (pytest.warns( + UserWarning, + match="The 2024.12 version of the array API specification is in draft status" + ), + ArrayAPIStrictFlags(api_version=draft_version), + ): + x_where = xp.where(x == 1, xp.asarray(42), 44) + + expected = xp.asarray([42, 44, 44, 42]) + assert xp.all(x_where == expected) + + # The spec does not allow both x1 and x2 to be scalars + with pytest.raises(ValueError, match="One of"): + xp.where(x == 1, 42, 44) From 50e8809ef397c8bece17b54e02442b01dee3c8b9 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 3 Feb 2025 16:11:30 +0100 Subject: [PATCH 219/252] BLD: use pyproject.toml not setup.py While at it, remove versioneer and use setuptools_scm for computing the version dynamically. This change is mainly because I did not manage to make versioneer work with pyproject.toml, while setuptools_scm "just worked" (well, nearly). --- .gitignore | 3 + array_api_strict/__init__.py | 9 +- array_api_strict/_version.py | 683 ---------- pyproject.toml | 33 + setup.cfg | 12 - setup.py | 29 - versioneer.py | 2277 ---------------------------------- 7 files changed, 42 insertions(+), 3004 deletions(-) delete mode 100644 array_api_strict/_version.py create mode 100644 pyproject.toml delete mode 100644 setup.cfg delete mode 100644 setup.py delete mode 100644 versioneer.py diff --git a/.gitignore b/.gitignore index f69e911..5feba5d 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,9 @@ share/python-wheels/ *.egg MANIFEST +*_version.py + + # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 27dec55..e6a1763 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -327,9 +327,12 @@ __all__ += ['set_array_api_strict_flags', 'get_array_api_strict_flags', 'reset_array_api_strict_flags', 'ArrayAPIStrictFlags'] -from . import _version -__version__ = _version.get_versions()['version'] -del _version +try: + from . import _version + __version__ = _version.__version__ + del _version +except ImportError: + __version__ = "unknown" # Extensions can be enabled or disabled dynamically. In order to make diff --git a/array_api_strict/_version.py b/array_api_strict/_version.py deleted file mode 100644 index 8218393..0000000 --- a/array_api_strict/_version.py +++ /dev/null @@ -1,683 +0,0 @@ - -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. -# Generated by versioneer-0.29 -# https://github.com/python-versioneer/python-versioneer - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys -from typing import Any, Callable, Dict, List, Optional, Tuple -import functools - - -def get_keywords() -> Dict[str, str]: - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - git_date = "$Format:%ci$" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - VCS: str - style: str - tag_prefix: str - parentdir_prefix: str - versionfile_source: str - verbose: bool - - -def get_config() -> VersioneerConfig: - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "pep440" - cfg.tag_prefix = "" - cfg.parentdir_prefix = "" - cfg.versionfile_source = "array_api_strict/_version.py" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY: Dict[str, str] = {} -HANDLERS: Dict[str, Dict[str, Callable]] = {} - - -def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator - """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f: Callable) -> Callable: - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command( - commands: List[str], - args: List[str], - cwd: Optional[str] = None, - verbose: bool = False, - hide_stderr: bool = False, - env: Optional[Dict[str, str]] = None, -) -> Tuple[Optional[str], Optional[int]]: - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - - popen_kwargs: Dict[str, Any] = {} - if sys.platform == "win32": - # This hides the console window if pythonw.exe is used - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW - popen_kwargs["startupinfo"] = startupinfo - - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) - break - except OSError as e: - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, process.returncode - return stdout, process.returncode - - -def versions_from_parentdir( - parentdir_prefix: str, - root: str, - verbose: bool, -) -> Dict[str, Any]: - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords: Dict[str, str] = {} - try: - with open(versionfile_abs, "r") as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords( - keywords: Dict[str, str], - tag_prefix: str, - verbose: bool, -) -> Dict[str, Any]: - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r'\d', r): - continue - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs( - tag_prefix: str, - root: str, - verbose: bool, - runner: Callable = run_command -) -> Dict[str, Any]: - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - # GIT_DIR can interfere with correct operation of Versioneer. - # It may be intended to be passed to the Versioneer-versioned project, - # but that should not change where we get our version from. - env = os.environ.copy() - env.pop("GIT_DIR", None) - runner = functools.partial(runner, env=env) - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=not verbose) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces: Dict[str, Any] = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. - branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. - branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. - branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) - pieces["distance"] = len(out.split()) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces: Dict[str, Any]) -> str: - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces: Dict[str, Any]) -> str: - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces: Dict[str, Any]) -> str: - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). - - Exceptions: - 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: - """Split pep440 version string at the post-release segment. - - Returns the release segments before the post-release and the - post-release version number (or -1 if no post-release segment is present). - """ - vc = str.split(ver, ".post") - return vc[0], int(vc[1] or 0) if len(vc) == 2 else None - - -def render_pep440_pre(pieces: Dict[str, Any]) -> str: - """TAG[.postN.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - if pieces["distance"]: - # update the post release segment - tag_version, post_version = pep440_split_post(pieces["closest-tag"]) - rendered = tag_version - if post_version is not None: - rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) - else: - rendered += ".post0.dev%d" % (pieces["distance"]) - else: - # no commits, use the tag as the version - rendered = pieces["closest-tag"] - else: - # exception #1 - rendered = "0.post0.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces: Dict[str, Any]) -> str: - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces: Dict[str, Any]) -> str: - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces: Dict[str, Any]) -> str: - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces: Dict[str, Any]) -> str: - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions() -> Dict[str, Any]: - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for _ in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..710f23e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,33 @@ +[build-system] +requires = ["setuptools >= 61.0", "setuptools_scm>8"] +build-backend = "setuptools.build_meta" + +[project] +name = "array_api_strict" +dynamic = ["version"] +requires-python = ">= 3.9" +dependencies = ["numpy"] +license = {file = "LICENSE"} +authors = [ + {name = "Consortium for Python Data API Standards"} +] +description = "A strict, minimal implementation of the Python array API standard." +readme = "README.md" +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", +] + +[project-urls] +Homepage = "https://data-apis.org/array-api-strict/" +Repository = "https://github.com/data-apis/array-api-strict" + +[tool.setuptools_scm] +version_file = "array_api_strict/_version.py" + diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index b551ef1..0000000 --- a/setup.cfg +++ /dev/null @@ -1,12 +0,0 @@ - -# See the docstring in versioneer.py for instructions. Note that you must -# re-run 'versioneer.py setup' after changing this section, and commit the -# resulting files. - -[versioneer] -VCS = git -style = pep440 -versionfile_source = array_api_strict/_version.py -versionfile_build = array_api_strict/_version.py -tag_prefix = -parentdir_prefix = diff --git a/setup.py b/setup.py deleted file mode 100644 index 29a94df..0000000 --- a/setup.py +++ /dev/null @@ -1,29 +0,0 @@ -from setuptools import setup, find_packages -import versioneer - -with open("README.md", "r") as fh: - long_description = fh.read() - -setup( - name='array_api_strict', - version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), - packages=find_packages(include=['array_api_strict*']), - author="Consortium for Python Data API Standards", - description="A strict, minimal implementation of the Python array API standard.", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://data-apis.org/array-api-strict/", - license="MIT", - python_requires=">=3.9", - install_requires=["numpy"], - classifiers=[ - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "License :: OSI Approved :: BSD License", - "Operating System :: OS Independent", - ], -) diff --git a/versioneer.py b/versioneer.py deleted file mode 100644 index 1e3753e..0000000 --- a/versioneer.py +++ /dev/null @@ -1,2277 +0,0 @@ - -# Version: 0.29 - -"""The Versioneer - like a rocketeer, but for versions. - -The Versioneer -============== - -* like a rocketeer, but for versions! -* https://github.com/python-versioneer/python-versioneer -* Brian Warner -* License: Public Domain (Unlicense) -* Compatible with: Python 3.7, 3.8, 3.9, 3.10, 3.11 and pypy3 -* [![Latest Version][pypi-image]][pypi-url] -* [![Build Status][travis-image]][travis-url] - -This is a tool for managing a recorded version number in setuptools-based -python projects. The goal is to remove the tedious and error-prone "update -the embedded version string" step from your release process. Making a new -release should be as easy as recording a new tag in your version-control -system, and maybe making new tarballs. - - -## Quick Install - -Versioneer provides two installation modes. The "classic" vendored mode installs -a copy of versioneer into your repository. The experimental build-time dependency mode -is intended to allow you to skip this step and simplify the process of upgrading. - -### Vendored mode - -* `pip install versioneer` to somewhere in your $PATH - * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is - available, so you can also use `conda install -c conda-forge versioneer` -* add a `[tool.versioneer]` section to your `pyproject.toml` or a - `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) - * Note that you will need to add `tomli; python_version < "3.11"` to your - build-time dependencies if you use `pyproject.toml` -* run `versioneer install --vendor` in your source tree, commit the results -* verify version information with `python setup.py version` - -### Build-time dependency mode - -* `pip install versioneer` to somewhere in your $PATH - * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is - available, so you can also use `conda install -c conda-forge versioneer` -* add a `[tool.versioneer]` section to your `pyproject.toml` or a - `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) -* add `versioneer` (with `[toml]` extra, if configuring in `pyproject.toml`) - to the `requires` key of the `build-system` table in `pyproject.toml`: - ```toml - [build-system] - requires = ["setuptools", "versioneer[toml]"] - build-backend = "setuptools.build_meta" - ``` -* run `versioneer install --no-vendor` in your source tree, commit the results -* verify version information with `python setup.py version` - -## Version Identifiers - -Source trees come from a variety of places: - -* a version-control system checkout (mostly used by developers) -* a nightly tarball, produced by build automation -* a snapshot tarball, produced by a web-based VCS browser, like github's - "tarball from tag" feature -* a release tarball, produced by "setup.py sdist", distributed through PyPI - -Within each source tree, the version identifier (either a string or a number, -this tool is format-agnostic) can come from a variety of places: - -* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows - about recent "tags" and an absolute revision-id -* the name of the directory into which the tarball was unpacked -* an expanded VCS keyword ($Id$, etc) -* a `_version.py` created by some earlier build step - -For released software, the version identifier is closely related to a VCS -tag. Some projects use tag names that include more than just the version -string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool -needs to strip the tag prefix to extract the version identifier. For -unreleased software (between tags), the version identifier should provide -enough information to help developers recreate the same tree, while also -giving them an idea of roughly how old the tree is (after version 1.2, before -version 1.3). Many VCS systems can report a description that captures this, -for example `git describe --tags --dirty --always` reports things like -"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the -0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes). - -The version identifier is used for multiple purposes: - -* to allow the module to self-identify its version: `myproject.__version__` -* to choose a name and prefix for a 'setup.py sdist' tarball - -## Theory of Operation - -Versioneer works by adding a special `_version.py` file into your source -tree, where your `__init__.py` can import it. This `_version.py` knows how to -dynamically ask the VCS tool for version information at import time. - -`_version.py` also contains `$Revision$` markers, and the installation -process marks `_version.py` to have this marker rewritten with a tag name -during the `git archive` command. As a result, generated tarballs will -contain enough information to get the proper version. - -To allow `setup.py` to compute a version too, a `versioneer.py` is added to -the top level of your source tree, next to `setup.py` and the `setup.cfg` -that configures it. This overrides several distutils/setuptools commands to -compute the version when invoked, and changes `setup.py build` and `setup.py -sdist` to replace `_version.py` with a small static file that contains just -the generated version data. - -## Installation - -See [INSTALL.md](./INSTALL.md) for detailed installation instructions. - -## Version-String Flavors - -Code which uses Versioneer can learn about its version string at runtime by -importing `_version` from your main `__init__.py` file and running the -`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can -import the top-level `versioneer.py` and run `get_versions()`. - -Both functions return a dictionary with different flavors of version -information: - -* `['version']`: A condensed version string, rendered using the selected - style. This is the most commonly used value for the project's version - string. The default "pep440" style yields strings like `0.11`, - `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section - below for alternative styles. - -* `['full-revisionid']`: detailed revision identifier. For Git, this is the - full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". - -* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the - commit date in ISO 8601 format. This will be None if the date is not - available. - -* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that - this is only accurate if run in a VCS checkout, otherwise it is likely to - be False or None - -* `['error']`: if the version string could not be computed, this will be set - to a string describing the problem, otherwise it will be None. It may be - useful to throw an exception in setup.py if this is set, to avoid e.g. - creating tarballs with a version string of "unknown". - -Some variants are more useful than others. Including `full-revisionid` in a -bug report should allow developers to reconstruct the exact code being tested -(or indicate the presence of local changes that should be shared with the -developers). `version` is suitable for display in an "about" box or a CLI -`--version` output: it can be easily compared against release notes and lists -of bugs fixed in various releases. - -The installer adds the following text to your `__init__.py` to place a basic -version in `YOURPROJECT.__version__`: - - from ._version import get_versions - __version__ = get_versions()['version'] - del get_versions - -## Styles - -The setup.cfg `style=` configuration controls how the VCS information is -rendered into a version string. - -The default style, "pep440", produces a PEP440-compliant string, equal to the -un-prefixed tag name for actual releases, and containing an additional "local -version" section with more detail for in-between builds. For Git, this is -TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags ---dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the -tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and -that this commit is two revisions ("+2") beyond the "0.11" tag. For released -software (exactly equal to a known tag), the identifier will only contain the -stripped tag, e.g. "0.11". - -Other styles are available. See [details.md](details.md) in the Versioneer -source tree for descriptions. - -## Debugging - -Versioneer tries to avoid fatal errors: if something goes wrong, it will tend -to return a version of "0+unknown". To investigate the problem, run `setup.py -version`, which will run the version-lookup code in a verbose mode, and will -display the full contents of `get_versions()` (including the `error` string, -which may help identify what went wrong). - -## Known Limitations - -Some situations are known to cause problems for Versioneer. This details the -most significant ones. More can be found on Github -[issues page](https://github.com/python-versioneer/python-versioneer/issues). - -### Subprojects - -Versioneer has limited support for source trees in which `setup.py` is not in -the root directory (e.g. `setup.py` and `.git/` are *not* siblings). The are -two common reasons why `setup.py` might not be in the root: - -* Source trees which contain multiple subprojects, such as - [Buildbot](https://github.com/buildbot/buildbot), which contains both - "master" and "slave" subprojects, each with their own `setup.py`, - `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI - distributions (and upload multiple independently-installable tarballs). -* Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other languages) in subdirectories. - -Versioneer will look for `.git` in parent directories, and most operations -should get the right version string. However `pip` and `setuptools` have bugs -and implementation details which frequently cause `pip install .` from a -subproject directory to fail to find a correct version string (so it usually -defaults to `0+unknown`). - -`pip install --editable .` should work correctly. `setup.py install` might -work too. - -Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in -some later version. - -[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking -this issue. The discussion in -[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the -issue from the Versioneer side in more detail. -[pip PR#3176](https://github.com/pypa/pip/pull/3176) and -[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve -pip to let Versioneer work correctly. - -Versioneer-0.16 and earlier only looked for a `.git` directory next to the -`setup.cfg`, so subprojects were completely unsupported with those releases. - -### Editable installs with setuptools <= 18.5 - -`setup.py develop` and `pip install --editable .` allow you to install a -project into a virtualenv once, then continue editing the source code (and -test) without re-installing after every change. - -"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a -convenient way to specify executable scripts that should be installed along -with the python package. - -These both work as expected when using modern setuptools. When using -setuptools-18.5 or earlier, however, certain operations will cause -`pkg_resources.DistributionNotFound` errors when running the entrypoint -script, which must be resolved by re-installing the package. This happens -when the install happens with one version, then the egg_info data is -regenerated while a different version is checked out. Many setup.py commands -cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into -a different virtualenv), so this can be surprising. - -[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes -this one, but upgrading to a newer version of setuptools should probably -resolve it. - - -## Updating Versioneer - -To upgrade your project to a new release of Versioneer, do the following: - -* install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg` and `pyproject.toml`, if necessary, - to include any new configuration settings indicated by the release notes. - See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install --[no-]vendor` in your source tree, to replace - `SRC/_version.py` -* commit any changed files - -## Future Directions - -This tool is designed to make it easily extended to other version-control -systems: all VCS-specific components are in separate directories like -src/git/ . The top-level `versioneer.py` script is assembled from these -components by running make-versioneer.py . In the future, make-versioneer.py -will take a VCS name as an argument, and will construct a version of -`versioneer.py` that is specific to the given VCS. It might also take the -configuration arguments that are currently provided manually during -installation by editing setup.py . Alternatively, it might go the other -direction and include code from all supported VCS systems, reducing the -number of intermediate scripts. - -## Similar projects - -* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time - dependency -* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of - versioneer -* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools - plugin - -## License - -To make Versioneer easier to embed, all its code is dedicated to the public -domain. The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the "Unlicense", as described in -https://unlicense.org/. - -[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg -[pypi-url]: https://pypi.python.org/pypi/versioneer/ -[travis-image]: -https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg -[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer - -""" -# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring -# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements -# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error -# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with -# pylint:disable=attribute-defined-outside-init,too-many-arguments - -import configparser -import errno -import json -import os -import re -import subprocess -import sys -from pathlib import Path -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union -from typing import NoReturn -import functools - -have_tomllib = True -if sys.version_info >= (3, 11): - import tomllib -else: - try: - import tomli as tomllib - except ImportError: - have_tomllib = False - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - VCS: str - style: str - tag_prefix: str - versionfile_source: str - versionfile_build: Optional[str] - parentdir_prefix: Optional[str] - verbose: Optional[bool] - - -def get_root() -> str: - """Get the project root directory. - - We require that all commands are run from the project root, i.e. the - directory that contains setup.py, setup.cfg, and versioneer.py . - """ - root = os.path.realpath(os.path.abspath(os.getcwd())) - setup_py = os.path.join(root, "setup.py") - pyproject_toml = os.path.join(root, "pyproject.toml") - versioneer_py = os.path.join(root, "versioneer.py") - if not ( - os.path.exists(setup_py) - or os.path.exists(pyproject_toml) - or os.path.exists(versioneer_py) - ): - # allow 'python path/to/setup.py COMMAND' - root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) - setup_py = os.path.join(root, "setup.py") - pyproject_toml = os.path.join(root, "pyproject.toml") - versioneer_py = os.path.join(root, "versioneer.py") - if not ( - os.path.exists(setup_py) - or os.path.exists(pyproject_toml) - or os.path.exists(versioneer_py) - ): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") - raise VersioneerBadRootError(err) - try: - # Certain runtime workflows (setup.py install/develop in a setuptools - # tree) execute all dependencies in a single python process, so - # "versioneer" may be imported multiple times, and python's shared - # module-import table will cache the first one. So we can't use - # os.path.dirname(__file__), as that will find whichever - # versioneer.py was first imported, even in later projects. - my_path = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(my_path)[0]) - vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir and "VERSIONEER_PEP518" not in globals(): - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(my_path), versioneer_py)) - except NameError: - pass - return root - - -def get_config_from_root(root: str) -> VersioneerConfig: - """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise OSError (if setup.cfg is missing), or - # configparser.NoSectionError (if it lacks a [versioneer] section), or - # configparser.NoOptionError (if it lacks "VCS="). See the docstring at - # the top of versioneer.py for instructions on writing your setup.cfg . - root_pth = Path(root) - pyproject_toml = root_pth / "pyproject.toml" - setup_cfg = root_pth / "setup.cfg" - section: Union[Dict[str, Any], configparser.SectionProxy, None] = None - if pyproject_toml.exists() and have_tomllib: - try: - with open(pyproject_toml, 'rb') as fobj: - pp = tomllib.load(fobj) - section = pp['tool']['versioneer'] - except (tomllib.TOMLDecodeError, KeyError) as e: - print(f"Failed to load config from {pyproject_toml}: {e}") - print("Try to load it from setup.cfg") - if not section: - parser = configparser.ConfigParser() - with open(setup_cfg) as cfg_file: - parser.read_file(cfg_file) - parser.get("versioneer", "VCS") # raise error if missing - - section = parser["versioneer"] - - # `cast`` really shouldn't be used, but its simplest for the - # common VersioneerConfig users at the moment. We verify against - # `None` values elsewhere where it matters - - cfg = VersioneerConfig() - cfg.VCS = section['VCS'] - cfg.style = section.get("style", "") - cfg.versionfile_source = cast(str, section.get("versionfile_source")) - cfg.versionfile_build = section.get("versionfile_build") - cfg.tag_prefix = cast(str, section.get("tag_prefix")) - if cfg.tag_prefix in ("''", '""', None): - cfg.tag_prefix = "" - cfg.parentdir_prefix = section.get("parentdir_prefix") - if isinstance(section, configparser.SectionProxy): - # Make sure configparser translates to bool - cfg.verbose = section.getboolean("verbose") - else: - cfg.verbose = section.get("verbose") - - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -# these dictionaries contain VCS-specific tools -LONG_VERSION_PY: Dict[str, str] = {} -HANDLERS: Dict[str, Dict[str, Callable]] = {} - - -def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator - """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f: Callable) -> Callable: - """Store f in HANDLERS[vcs][method].""" - HANDLERS.setdefault(vcs, {})[method] = f - return f - return decorate - - -def run_command( - commands: List[str], - args: List[str], - cwd: Optional[str] = None, - verbose: bool = False, - hide_stderr: bool = False, - env: Optional[Dict[str, str]] = None, -) -> Tuple[Optional[str], Optional[int]]: - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - - popen_kwargs: Dict[str, Any] = {} - if sys.platform == "win32": - # This hides the console window if pythonw.exe is used - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW - popen_kwargs["startupinfo"] = startupinfo - - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) - break - except OSError as e: - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, process.returncode - return stdout, process.returncode - - -LONG_VERSION_PY['git'] = r''' -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. -# Generated by versioneer-0.29 -# https://github.com/python-versioneer/python-versioneer - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys -from typing import Any, Callable, Dict, List, Optional, Tuple -import functools - - -def get_keywords() -> Dict[str, str]: - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" - git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" - git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - VCS: str - style: str - tag_prefix: str - parentdir_prefix: str - versionfile_source: str - verbose: bool - - -def get_config() -> VersioneerConfig: - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "%(STYLE)s" - cfg.tag_prefix = "%(TAG_PREFIX)s" - cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" - cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY: Dict[str, str] = {} -HANDLERS: Dict[str, Dict[str, Callable]] = {} - - -def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator - """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f: Callable) -> Callable: - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command( - commands: List[str], - args: List[str], - cwd: Optional[str] = None, - verbose: bool = False, - hide_stderr: bool = False, - env: Optional[Dict[str, str]] = None, -) -> Tuple[Optional[str], Optional[int]]: - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - - popen_kwargs: Dict[str, Any] = {} - if sys.platform == "win32": - # This hides the console window if pythonw.exe is used - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW - popen_kwargs["startupinfo"] = startupinfo - - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) - break - except OSError as e: - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %%s" %% dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %%s" %% (commands,)) - return None, None - stdout = process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %%s (error)" %% dispcmd) - print("stdout was %%s" %% stdout) - return None, process.returncode - return stdout, process.returncode - - -def versions_from_parentdir( - parentdir_prefix: str, - root: str, - verbose: bool, -) -> Dict[str, Any]: - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %%s but none started with prefix %%s" %% - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords: Dict[str, str] = {} - try: - with open(versionfile_abs, "r") as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords( - keywords: Dict[str, str], - tag_prefix: str, - verbose: bool, -) -> Dict[str, Any]: - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %%d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} - if verbose: - print("discarding '%%s', no digits" %% ",".join(refs - tags)) - if verbose: - print("likely tags: %%s" %% ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r'\d', r): - continue - if verbose: - print("picking %%s" %% r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs( - tag_prefix: str, - root: str, - verbose: bool, - runner: Callable = run_command -) -> Dict[str, Any]: - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - # GIT_DIR can interfere with correct operation of Versioneer. - # It may be intended to be passed to the Versioneer-versioned project, - # but that should not change where we get our version from. - env = os.environ.copy() - env.pop("GIT_DIR", None) - runner = functools.partial(runner, env=env) - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=not verbose) - if rc != 0: - if verbose: - print("Directory %%s not under git control" %% root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces: Dict[str, Any] = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. - branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. - branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. - branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%%s'" - %% describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%%s' doesn't start with prefix '%%s'" - print(fmt %% (full_tag, tag_prefix)) - pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" - %% (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) - pieces["distance"] = len(out.split()) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces: Dict[str, Any]) -> str: - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces: Dict[str, Any]) -> str: - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces: Dict[str, Any]) -> str: - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). - - Exceptions: - 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: - """Split pep440 version string at the post-release segment. - - Returns the release segments before the post-release and the - post-release version number (or -1 if no post-release segment is present). - """ - vc = str.split(ver, ".post") - return vc[0], int(vc[1] or 0) if len(vc) == 2 else None - - -def render_pep440_pre(pieces: Dict[str, Any]) -> str: - """TAG[.postN.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - if pieces["distance"]: - # update the post release segment - tag_version, post_version = pep440_split_post(pieces["closest-tag"]) - rendered = tag_version - if post_version is not None: - rendered += ".post%%d.dev%%d" %% (post_version + 1, pieces["distance"]) - else: - rendered += ".post0.dev%%d" %% (pieces["distance"]) - else: - # no commits, use the tag as the version - rendered = pieces["closest-tag"] - else: - # exception #1 - rendered = "0.post0.dev%%d" %% pieces["distance"] - return rendered - - -def render_pep440_post(pieces: Dict[str, Any]) -> str: - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces: Dict[str, Any]) -> str: - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces: Dict[str, Any]) -> str: - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces: Dict[str, Any]) -> str: - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%%s'" %% style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions() -> Dict[str, Any]: - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for _ in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} -''' - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords: Dict[str, str] = {} - try: - with open(versionfile_abs, "r") as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords( - keywords: Dict[str, str], - tag_prefix: str, - verbose: bool, -) -> Dict[str, Any]: - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r'\d', r): - continue - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs( - tag_prefix: str, - root: str, - verbose: bool, - runner: Callable = run_command -) -> Dict[str, Any]: - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - # GIT_DIR can interfere with correct operation of Versioneer. - # It may be intended to be passed to the Versioneer-versioned project, - # but that should not change where we get our version from. - env = os.environ.copy() - env.pop("GIT_DIR", None) - runner = functools.partial(runner, env=env) - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=not verbose) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces: Dict[str, Any] = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. - branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. - branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. - branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) - pieces["distance"] = len(out.split()) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def do_vcs_install(versionfile_source: str, ipy: Optional[str]) -> None: - """Git-specific installation logic for Versioneer. - - For Git, this means creating/changing .gitattributes to mark _version.py - for export-subst keyword substitution. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - files = [versionfile_source] - if ipy: - files.append(ipy) - if "VERSIONEER_PEP518" not in globals(): - try: - my_path = __file__ - if my_path.endswith((".pyc", ".pyo")): - my_path = os.path.splitext(my_path)[0] + ".py" - versioneer_file = os.path.relpath(my_path) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) - present = False - try: - with open(".gitattributes", "r") as fobj: - for line in fobj: - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - break - except OSError: - pass - if not present: - with open(".gitattributes", "a+") as fobj: - fobj.write(f"{versionfile_source} export-subst\n") - files.append(".gitattributes") - run_command(GITS, ["add", "--"] + files) - - -def versions_from_parentdir( - parentdir_prefix: str, - root: str, - verbose: bool, -) -> Dict[str, Any]: - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.29) from -# revision-control system data, or from the parent directory name of an -# unpacked source archive. Distribution tarballs contain a pre-generated copy -# of this file. - -import json - -version_json = ''' -%s -''' # END VERSION_JSON - - -def get_versions(): - return json.loads(version_json) -""" - - -def versions_from_file(filename: str) -> Dict[str, Any]: - """Try to determine the version from _version.py if present.""" - try: - with open(filename) as f: - contents = f.read() - except OSError: - raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - raise NotThisMethod("no version_json in _version.py") - return json.loads(mo.group(1)) - - -def write_to_version_file(filename: str, versions: Dict[str, Any]) -> None: - """Write the given version number to the given _version.py file.""" - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) - with open(filename, "w") as f: - f.write(SHORT_VERSION_PY % contents) - - print("set %s to '%s'" % (filename, versions["version"])) - - -def plus_or_dot(pieces: Dict[str, Any]) -> str: - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces: Dict[str, Any]) -> str: - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces: Dict[str, Any]) -> str: - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). - - Exceptions: - 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: - """Split pep440 version string at the post-release segment. - - Returns the release segments before the post-release and the - post-release version number (or -1 if no post-release segment is present). - """ - vc = str.split(ver, ".post") - return vc[0], int(vc[1] or 0) if len(vc) == 2 else None - - -def render_pep440_pre(pieces: Dict[str, Any]) -> str: - """TAG[.postN.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - if pieces["distance"]: - # update the post release segment - tag_version, post_version = pep440_split_post(pieces["closest-tag"]) - rendered = tag_version - if post_version is not None: - rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) - else: - rendered += ".post0.dev%d" % (pieces["distance"]) - else: - # no commits, use the tag as the version - rendered = pieces["closest-tag"] - else: - # exception #1 - rendered = "0.post0.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces: Dict[str, Any]) -> str: - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces: Dict[str, Any]) -> str: - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces: Dict[str, Any]) -> str: - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces: Dict[str, Any]) -> str: - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -class VersioneerBadRootError(Exception): - """The project root directory is unknown or missing key files.""" - - -def get_versions(verbose: bool = False) -> Dict[str, Any]: - """Get the project version from whatever source is available. - - Returns dict with two keys: 'version' and 'full'. - """ - if "versioneer" in sys.modules: - # see the discussion in cmdclass.py:get_cmdclass() - del sys.modules["versioneer"] - - root = get_root() - cfg = get_config_from_root(root) - - assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" - handlers = HANDLERS.get(cfg.VCS) - assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or bool(cfg.verbose) # `bool()` used to avoid `None` - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" - assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" - - versionfile_abs = os.path.join(root, cfg.versionfile_source) - - # extract version from first of: _version.py, VCS command (e.g. 'git - # describe'), parentdir. This is meant to work for developers using a - # source checkout, for users of a tarball created by 'setup.py sdist', - # and for users of a tarball/zipball created by 'git archive' or github's - # download-from-tag feature or the equivalent in other VCSes. - - get_keywords_f = handlers.get("get_keywords") - from_keywords_f = handlers.get("keywords") - if get_keywords_f and from_keywords_f: - try: - keywords = get_keywords_f(versionfile_abs) - ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) - if verbose: - print("got version from expanded keyword %s" % ver) - return ver - except NotThisMethod: - pass - - try: - ver = versions_from_file(versionfile_abs) - if verbose: - print("got version from file %s %s" % (versionfile_abs, ver)) - return ver - except NotThisMethod: - pass - - from_vcs_f = handlers.get("pieces_from_vcs") - if from_vcs_f: - try: - pieces = from_vcs_f(cfg.tag_prefix, root, verbose) - ver = render(pieces, cfg.style) - if verbose: - print("got version from VCS %s" % ver) - return ver - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - if verbose: - print("got version from parentdir %s" % ver) - return ver - except NotThisMethod: - pass - - if verbose: - print("unable to compute version") - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} - - -def get_version() -> str: - """Get the short version string for this project.""" - return get_versions()["version"] - - -def get_cmdclass(cmdclass: Optional[Dict[str, Any]] = None): - """Get the custom setuptools subclasses used by Versioneer. - - If the package uses a different cmdclass (e.g. one from numpy), it - should be provide as an argument. - """ - if "versioneer" in sys.modules: - del sys.modules["versioneer"] - # this fixes the "python setup.py develop" case (also 'install' and - # 'easy_install .'), in which subdependencies of the main project are - # built (using setup.py bdist_egg) in the same python process. Assume - # a main project A and a dependency B, which use different versions - # of Versioneer. A's setup.py imports A's Versioneer, leaving it in - # sys.modules by the time B's setup.py is executed, causing B to run - # with the wrong versioneer. Setuptools wraps the sub-dep builds in a - # sandbox that restores sys.modules to it's pre-build state, so the - # parent is protected against the child's "import versioneer". By - # removing ourselves from sys.modules here, before the child build - # happens, we protect the child from the parent's versioneer too. - # Also see https://github.com/python-versioneer/python-versioneer/issues/52 - - cmds = {} if cmdclass is None else cmdclass.copy() - - # we add "version" to setuptools - from setuptools import Command - - class cmd_version(Command): - description = "report generated version string" - user_options: List[Tuple[str, str, str]] = [] - boolean_options: List[str] = [] - - def initialize_options(self) -> None: - pass - - def finalize_options(self) -> None: - pass - - def run(self) -> None: - vers = get_versions(verbose=True) - print("Version: %s" % vers["version"]) - print(" full-revisionid: %s" % vers.get("full-revisionid")) - print(" dirty: %s" % vers.get("dirty")) - print(" date: %s" % vers.get("date")) - if vers["error"]: - print(" error: %s" % vers["error"]) - cmds["version"] = cmd_version - - # we override "build_py" in setuptools - # - # most invocation pathways end up running build_py: - # distutils/build -> build_py - # distutils/install -> distutils/build ->.. - # setuptools/bdist_wheel -> distutils/install ->.. - # setuptools/bdist_egg -> distutils/install_lib -> build_py - # setuptools/install -> bdist_egg ->.. - # setuptools/develop -> ? - # pip install: - # copies source tree to a tempdir before running egg_info/etc - # if .git isn't copied too, 'git describe' will fail - # then does setup.py bdist_wheel, or sometimes setup.py install - # setup.py egg_info -> ? - - # pip install -e . and setuptool/editable_wheel will invoke build_py - # but the build_py command is not expected to copy any files. - - # we override different "build_py" commands for both environments - if 'build_py' in cmds: - _build_py: Any = cmds['build_py'] - else: - from setuptools.command.build_py import build_py as _build_py - - class cmd_build_py(_build_py): - def run(self) -> None: - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_py.run(self) - if getattr(self, "editable_mode", False): - # During editable installs `.py` and data files are - # not copied to build_lib - return - # now locate _version.py in the new build/ directory and replace - # it with an updated value - if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - cmds["build_py"] = cmd_build_py - - if 'build_ext' in cmds: - _build_ext: Any = cmds['build_ext'] - else: - from setuptools.command.build_ext import build_ext as _build_ext - - class cmd_build_ext(_build_ext): - def run(self) -> None: - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_ext.run(self) - if self.inplace: - # build_ext --inplace will only build extensions in - # build/lib<..> dir with no _version.py to write to. - # As in place builds will already have a _version.py - # in the module dir, we do not need to write one. - return - # now locate _version.py in the new build/ directory and replace - # it with an updated value - if not cfg.versionfile_build: - return - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) - if not os.path.exists(target_versionfile): - print(f"Warning: {target_versionfile} does not exist, skipping " - "version update. This can happen if you are running build_ext " - "without first running build_py.") - return - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - cmds["build_ext"] = cmd_build_ext - - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe # type: ignore - # nczeczulin reports that py2exe won't like the pep440-style string - # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. - # setup(console=[{ - # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION - # "product_version": versioneer.get_version(), - # ... - - class cmd_build_exe(_build_exe): - def run(self) -> None: - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _build_exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["build_exe"] = cmd_build_exe - del cmds["build_py"] - - if 'py2exe' in sys.modules: # py2exe enabled? - try: - from py2exe.setuptools_buildexe import py2exe as _py2exe # type: ignore - except ImportError: - from py2exe.distutils_buildexe import py2exe as _py2exe # type: ignore - - class cmd_py2exe(_py2exe): - def run(self) -> None: - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _py2exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["py2exe"] = cmd_py2exe - - # sdist farms its file list building out to egg_info - if 'egg_info' in cmds: - _egg_info: Any = cmds['egg_info'] - else: - from setuptools.command.egg_info import egg_info as _egg_info - - class cmd_egg_info(_egg_info): - def find_sources(self) -> None: - # egg_info.find_sources builds the manifest list and writes it - # in one shot - super().find_sources() - - # Modify the filelist and normalize it - root = get_root() - cfg = get_config_from_root(root) - self.filelist.append('versioneer.py') - if cfg.versionfile_source: - # There are rare cases where versionfile_source might not be - # included by default, so we must be explicit - self.filelist.append(cfg.versionfile_source) - self.filelist.sort() - self.filelist.remove_duplicates() - - # The write method is hidden in the manifest_maker instance that - # generated the filelist and was thrown away - # We will instead replicate their final normalization (to unicode, - # and POSIX-style paths) - from setuptools import unicode_utils - normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/') - for f in self.filelist.files] - - manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt') - with open(manifest_filename, 'w') as fobj: - fobj.write('\n'.join(normalized)) - - cmds['egg_info'] = cmd_egg_info - - # we override different "sdist" commands for both environments - if 'sdist' in cmds: - _sdist: Any = cmds['sdist'] - else: - from setuptools.command.sdist import sdist as _sdist - - class cmd_sdist(_sdist): - def run(self) -> None: - versions = get_versions() - self._versioneer_generated_versions = versions - # unless we update this, the command will keep using the old - # version - self.distribution.metadata.version = versions["version"] - return _sdist.run(self) - - def make_release_tree(self, base_dir: str, files: List[str]) -> None: - root = get_root() - cfg = get_config_from_root(root) - _sdist.make_release_tree(self, base_dir, files) - # now locate _version.py in the new base_dir directory - # (remembering that it may be a hardlink) and replace it with an - # updated value - target_versionfile = os.path.join(base_dir, cfg.versionfile_source) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) - cmds["sdist"] = cmd_sdist - - return cmds - - -CONFIG_ERROR = """ -setup.cfg is missing the necessary Versioneer configuration. You need -a section like: - - [versioneer] - VCS = git - style = pep440 - versionfile_source = src/myproject/_version.py - versionfile_build = myproject/_version.py - tag_prefix = - parentdir_prefix = myproject- - -You will also need to edit your setup.py to use the results: - - import versioneer - setup(version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), ...) - -Please read the docstring in ./versioneer.py for configuration instructions, -edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. -""" - -SAMPLE_CONFIG = """ -# See the docstring in versioneer.py for instructions. Note that you must -# re-run 'versioneer.py setup' after changing this section, and commit the -# resulting files. - -[versioneer] -#VCS = git -#style = pep440 -#versionfile_source = -#versionfile_build = -#tag_prefix = -#parentdir_prefix = - -""" - -OLD_SNIPPET = """ -from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions -""" - -INIT_PY_SNIPPET = """ -from . import {0} -__version__ = {0}.get_versions()['version'] -""" - - -def do_setup() -> int: - """Do main VCS-independent setup function for installing Versioneer.""" - root = get_root() - try: - cfg = get_config_from_root(root) - except (OSError, configparser.NoSectionError, - configparser.NoOptionError) as e: - if isinstance(e, (OSError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) - with open(os.path.join(root, "setup.cfg"), "a") as f: - f.write(SAMPLE_CONFIG) - print(CONFIG_ERROR, file=sys.stderr) - return 1 - - print(" creating %s" % cfg.versionfile_source) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") - maybe_ipy: Optional[str] = ipy - if os.path.exists(ipy): - try: - with open(ipy, "r") as f: - old = f.read() - except OSError: - old = "" - module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0] - snippet = INIT_PY_SNIPPET.format(module) - if OLD_SNIPPET in old: - print(" replacing boilerplate in %s" % ipy) - with open(ipy, "w") as f: - f.write(old.replace(OLD_SNIPPET, snippet)) - elif snippet not in old: - print(" appending to %s" % ipy) - with open(ipy, "a") as f: - f.write(snippet) - else: - print(" %s unmodified" % ipy) - else: - print(" %s doesn't exist, ok" % ipy) - maybe_ipy = None - - # Make VCS-specific changes. For git, this means creating/changing - # .gitattributes to mark _version.py for export-subst keyword - # substitution. - do_vcs_install(cfg.versionfile_source, maybe_ipy) - return 0 - - -def scan_setup_py() -> int: - """Validate the contents of setup.py against Versioneer's expectations.""" - found = set() - setters = False - errors = 0 - with open("setup.py", "r") as f: - for line in f.readlines(): - if "import versioneer" in line: - found.add("import") - if "versioneer.get_cmdclass()" in line: - found.add("cmdclass") - if "versioneer.get_version()" in line: - found.add("get_version") - if "versioneer.VCS" in line: - setters = True - if "versioneer.versionfile_source" in line: - setters = True - if len(found) != 3: - print("") - print("Your setup.py appears to be missing some important items") - print("(but I might be wrong). Please make sure it has something") - print("roughly like the following:") - print("") - print(" import versioneer") - print(" setup( version=versioneer.get_version(),") - print(" cmdclass=versioneer.get_cmdclass(), ...)") - print("") - errors += 1 - if setters: - print("You should remove lines like 'versioneer.VCS = ' and") - print("'versioneer.versionfile_source = ' . This configuration") - print("now lives in setup.cfg, and should be removed from setup.py") - print("") - errors += 1 - return errors - - -def setup_command() -> NoReturn: - """Set up Versioneer and exit with appropriate error code.""" - errors = do_setup() - errors += scan_setup_py() - sys.exit(1 if errors else 0) - - -if __name__ == "__main__": - cmd = sys.argv[1] - if cmd == "setup": - setup_command() From e3e01c8ef9f5193fd3958e636200cfe3b54ad0b8 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 5 Feb 2025 10:44:32 +0100 Subject: [PATCH 220/252] MAINT: block axis=None in squeeze closes https://github.com/data-apis/array-api-strict/issues/62 --- array_api_strict/_manipulation_functions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index d775835..63c3516 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -153,6 +153,11 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: See its docstring for more information. """ + if axis is None: + raise ValueError( + "squeeze(..., axis=None is not supported. See " + "https://github.com/data-apis/array-api/pull/100 for a discussion." + ) return Array._new(np.squeeze(x._array, axis=axis), device=x.device) From 7e9397044fda395ff0b1e2090d4981f41aa0a568 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 7 Feb 2025 11:03:40 +0100 Subject: [PATCH 221/252] ENH: real and conj accept numeric dtypes --- array_api_strict/_elementwise_functions.py | 8 ++++---- array_api_strict/tests/test_elementwise_functions.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 54691d6..c11b17c 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -378,8 +378,8 @@ def conj(x: Array, /) -> Array: See its docstring for more information. """ - if x.dtype not in _complex_floating_dtypes: - raise TypeError("Only complex floating-point dtypes are allowed in conj") + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in conj") return Array._new(np.conj(x._array), device=x.device) @@ -568,8 +568,8 @@ def real(x: Array, /) -> Array: See its docstring for more information. """ - if x.dtype not in _complex_floating_dtypes: - raise TypeError("Only complex floating-point dtypes are allowed in real") + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in real") return Array._new(np.real(x._array), device=x.device) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 0b90f0b..f38cdb9 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -52,7 +52,7 @@ def nargs(func): "bitwise_xor": "integer or boolean", "ceil": "real numeric", "clip": "real numeric", - "conj": "complex floating-point", + "conj": "numeric", "copysign": "real floating-point", "cos": "floating-point", "cosh": "floating-point", @@ -88,7 +88,7 @@ def nargs(func): "not_equal": "all", "positive": "numeric", "pow": "numeric", - "real": "complex floating-point", + "real": "numeric", "reciprocal": "floating-point", "remainder": "real numeric", "round": "numeric", From a8f93759a34f7b15107878d2de71c6d549d44bc0 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 17 Feb 2025 15:08:18 +0100 Subject: [PATCH 222/252] TST: fix tests for disallowed array indexing Several assertions were raising for "wrong" reasons. --- array_api_strict/tests/test_array_object.py | 49 +++++++++++---------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index edfa073..87fdbc3 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -45,35 +45,37 @@ def test_validate_index(): a = ones((3, 4)) # Out of bounds slices are not allowed - assert_raises(IndexError, lambda: a[:4]) - assert_raises(IndexError, lambda: a[:-4]) - assert_raises(IndexError, lambda: a[:3:-1]) - assert_raises(IndexError, lambda: a[:-5:-1]) - assert_raises(IndexError, lambda: a[4:]) - assert_raises(IndexError, lambda: a[-4:]) - assert_raises(IndexError, lambda: a[4::-1]) - assert_raises(IndexError, lambda: a[-4::-1]) - - assert_raises(IndexError, lambda: a[...,:5]) - assert_raises(IndexError, lambda: a[...,:-5]) - assert_raises(IndexError, lambda: a[...,:5:-1]) - assert_raises(IndexError, lambda: a[...,:-6:-1]) - assert_raises(IndexError, lambda: a[...,5:]) - assert_raises(IndexError, lambda: a[...,-5:]) - assert_raises(IndexError, lambda: a[...,5::-1]) - assert_raises(IndexError, lambda: a[...,-5::-1]) + assert_raises(IndexError, lambda: a[:4, 0]) + assert_raises(IndexError, lambda: a[:-4, 0]) + assert_raises(IndexError, lambda: a[:3:-1]) # XXX raises for a wrong reason + assert_raises(IndexError, lambda: a[:-5:-1, 0]) + assert_raises(IndexError, lambda: a[4:, 0]) + assert_raises(IndexError, lambda: a[-4:, 0]) + assert_raises(IndexError, lambda: a[4::-1, 0]) + assert_raises(IndexError, lambda: a[-4::-1, 0]) + + assert_raises(IndexError, lambda: a[..., :5]) + assert_raises(IndexError, lambda: a[..., :-5]) + assert_raises(IndexError, lambda: a[..., :5:-1]) + assert_raises(IndexError, lambda: a[..., :-6:-1]) + assert_raises(IndexError, lambda: a[..., 5:]) + assert_raises(IndexError, lambda: a[..., -5:]) + assert_raises(IndexError, lambda: a[..., 5::-1]) + assert_raises(IndexError, lambda: a[..., -5::-1]) # Boolean indices cannot be part of a larger tuple index - assert_raises(IndexError, lambda: a[a[:,0]==1,0]) - assert_raises(IndexError, lambda: a[a[:,0]==1,...]) - assert_raises(IndexError, lambda: a[..., a[0]==1]) + assert_raises(IndexError, lambda: a[a[:, 0] == 1, 0]) + assert_raises(IndexError, lambda: a[a[:, 0] == 1, ...]) + assert_raises(IndexError, lambda: a[..., a[0] == 1]) assert_raises(IndexError, lambda: a[[True, True, True]]) assert_raises(IndexError, lambda: a[(True, True, True),]) # Integer array indices are not allowed (except for 0-D) - idx = asarray([[0, 1]]) - assert_raises(IndexError, lambda: a[idx]) - assert_raises(IndexError, lambda: a[idx,]) + idx = asarray([0, 1]) + assert_raises(IndexError, lambda: a[idx, 0]) + assert_raises(IndexError, lambda: a[0, idx]) + + # Array-likes (lists, tuples) are not allowed as indices assert_raises(IndexError, lambda: a[[0, 1]]) assert_raises(IndexError, lambda: a[(0, 1), (0, 1)]) assert_raises(IndexError, lambda: a[[0, 1]]) @@ -87,6 +89,7 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[0,]) assert_raises(IndexError, lambda: a[0]) assert_raises(IndexError, lambda: a[:]) + assert_raises(IndexError, lambda: a[idx]) def test_promoted_scalar_inherits_device(): device1 = Device("device1") From c514dbc8f4b96de357fd0f387a7259b590a1e4a4 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 17 Feb 2025 16:26:50 +0100 Subject: [PATCH 223/252] ENH: allow 1D integer array indices --- array_api_strict/_array_object.py | 20 ++++++-- array_api_strict/tests/test_array_object.py | 51 +++++++++++++++++++-- 2 files changed, 64 insertions(+), 7 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index afee030..6d63577 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -395,6 +395,8 @@ def _validate_index(self, key): single_axes = [] n_ellipsis = 0 key_has_mask = False + key_has_index_array = False + key_has_slices = False for i in _key: if i is not None: nonexpanding_key.append(i) @@ -403,6 +405,8 @@ def _validate_index(self, key): if isinstance(i, Array): if i.dtype in _boolean_dtypes: key_has_mask = True + elif i.dtype in _integer_dtypes: + key_has_index_array = True single_axes.append(i) else: # i must not be an array here, to avoid elementwise equals @@ -410,6 +414,8 @@ def _validate_index(self, key): n_ellipsis += 1 else: single_axes.append(i) + if isinstance(i, slice): + key_has_slices = True n_single_axes = len(single_axes) if n_ellipsis > 1: @@ -427,6 +433,12 @@ def _validate_index(self, key): "specified in the Array API." ) + if (key_has_index_array and (n_ellipsis > 0 or key_has_slices or key_has_mask)): + raise IndexError( + "Integer index arrays are only allowed with integer indices; " + f"got {key}." + ) + if n_ellipsis == 0: indexed_shape = self.shape else: @@ -485,11 +497,11 @@ def _validate_index(self, key): if not get_array_api_strict_flags()['boolean_indexing']: raise RuntimeError("The boolean_indexing flag has been disabled for array-api-strict") - elif i.dtype in _integer_dtypes and i.ndim != 0: + elif i.dtype in _integer_dtypes and i.ndim > 1: raise IndexError( - f"Single-axes index {i} is a non-zero-dimensional " - "integer array, but advanced integer indexing is not " - "specified in the Array API." + f"Single-axes index {i} is a multi-dimensional " + "integer array, but advanced integer indexing is only " + "specified in the Array API for 1D index arrays." ) elif isinstance(i, tuple): raise IndexError( diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 87fdbc3..5747cb1 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from .. import ones, asarray, result_type, all, equal +from .. import ones, arange, reshape, asarray, result_type, all, equal from .._array_object import Array, CPU_DEVICE, Device from .._dtypes import ( _all_dtypes, @@ -70,11 +70,25 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[[True, True, True]]) assert_raises(IndexError, lambda: a[(True, True, True),]) - # Integer array indices are not allowed (except for 0-D) - idx = asarray([0, 1]) + # Integer array indices are not allowed (except for 0-D or 1D) + idx = asarray([[0, 1]]) # idx.ndim == 2 assert_raises(IndexError, lambda: a[idx, 0]) assert_raises(IndexError, lambda: a[0, idx]) + # Mixing 1D integer array indices with slices, ellipsis or booleans is not allowed + idx = asarray([0, 1]) + assert_raises(IndexError, lambda: a[..., idx]) + assert_raises(IndexError, lambda: a[:, idx]) + assert_raises(IndexError, lambda: a[asarray([True, True]), idx]) + + # 1D integer array indices must have the same length + idx1 = asarray([0, 1]) + idx2 = asarray([0, 1, 1]) + assert_raises(IndexError, lambda: a[idx1, idx2]) + + # Non-integer array indices are not allowed + assert_raises(IndexError, lambda: a[ones(2), 0]) + # Array-likes (lists, tuples) are not allowed as indices assert_raises(IndexError, lambda: a[[0, 1]]) assert_raises(IndexError, lambda: a[(0, 1), (0, 1)]) @@ -91,6 +105,37 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[:]) assert_raises(IndexError, lambda: a[idx]) + +def test_indexing_arrays(): + # indexing with 1D integer arrays and mixes of integers and 1D integer are allowed + + # 1D array + a = arange(5) + idx = asarray([1, 0, 1, 2, -1]) + a_idx = a[idx] + + a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])]) + assert all(a_idx == a_idx_loop) + + # setitem with arrays is not allowed # XXX + # with assert_raises(IndexError): + # a[idx] = 42 + + # mixed array and integer indexing + a = reshape(arange(3*4), (3, 4)) + idx = asarray([1, 0, 1, 2, -1]) + a_idx = a[idx, 1] + + a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])]) + assert all(a_idx == a_idx_loop) + + + # index with two arrays + a_idx = a[idx, idx] + a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])]) + assert all(a_idx == a_idx_loop) + + def test_promoted_scalar_inherits_device(): device1 = Device("device1") x = asarray([1., 2, 3], device=device1) From 36a370ada23e25206d3398a60d77c0d5a08e0636 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 18 Feb 2025 11:17:45 +0100 Subject: [PATCH 224/252] ENH: fancy indexing __setitem__ is not allowed --- array_api_strict/_array_object.py | 7 +++++-- array_api_strict/tests/test_array_object.py | 11 +++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 6d63577..1a8c566 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -327,7 +327,7 @@ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: # Note: A large fraction of allowed indices are disallowed here (see the # docstring below) - def _validate_index(self, key): + def _validate_index(self, key, op="getitem"): """ Validate an index according to the array API. @@ -390,6 +390,9 @@ def _validate_index(self, key): "zero-dimensional integer arrays and boolean arrays " "are specified in the Array API." ) + if op == "setitem": + if isinstance(i, Array) and i.dtype in _integer_dtypes: + raise IndexError("Fancy indexing __setitem__ is not supported.") nonexpanding_key = [] single_axes = [] @@ -914,7 +917,7 @@ def __setitem__( """ # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index - self._validate_index(key) + self._validate_index(key, op="setitem") if isinstance(key, Array): # Indexing self._array with array_api_strict arrays can be erroneous key = key._array diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 5747cb1..6a381d4 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -117,9 +117,9 @@ def test_indexing_arrays(): a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) - # setitem with arrays is not allowed # XXX - # with assert_raises(IndexError): - # a[idx] = 42 + # setitem with arrays is not allowed + with assert_raises(IndexError): + a[idx] = 42 # mixed array and integer indexing a = reshape(arange(3*4), (3, 4)) @@ -129,12 +129,15 @@ def test_indexing_arrays(): a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) - # index with two arrays a_idx = a[idx, idx] a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) + # setitem with arrays is not allowed + with assert_raises(IndexError): + a[idx, idx] = 42 + def test_promoted_scalar_inherits_device(): device1 = Device("device1") From 6664e6d241ca4fce8305821dd6a7ed143b5796c0 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 18 Feb 2025 20:04:30 +0100 Subject: [PATCH 225/252] ENH: allow ndim>1 indexing arrays --- array_api_strict/_array_object.py | 11 +++-------- array_api_strict/tests/test_array_object.py | 9 ++++----- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 1a8c566..0595594 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -498,14 +498,9 @@ def _validate_index(self, key, op="getitem"): "Array API when the array is the sole index." ) if not get_array_api_strict_flags()['boolean_indexing']: - raise RuntimeError("The boolean_indexing flag has been disabled for array-api-strict") - - elif i.dtype in _integer_dtypes and i.ndim > 1: - raise IndexError( - f"Single-axes index {i} is a multi-dimensional " - "integer array, but advanced integer indexing is only " - "specified in the Array API for 1D index arrays." - ) + raise RuntimeError( + "The boolean_indexing flag has been disabled for array-api-strict" + ) elif isinstance(i, tuple): raise IndexError( f"Single-axes index {i} is a tuple, but nested tuple " diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 6a381d4..ef76c28 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -70,11 +70,6 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[[True, True, True]]) assert_raises(IndexError, lambda: a[(True, True, True),]) - # Integer array indices are not allowed (except for 0-D or 1D) - idx = asarray([[0, 1]]) # idx.ndim == 2 - assert_raises(IndexError, lambda: a[idx, 0]) - assert_raises(IndexError, lambda: a[0, idx]) - # Mixing 1D integer array indices with slices, ellipsis or booleans is not allowed idx = asarray([0, 1]) assert_raises(IndexError, lambda: a[..., idx]) @@ -138,6 +133,10 @@ def test_indexing_arrays(): with assert_raises(IndexError): a[idx, idx] = 42 + # smoke test indexing with ndim > 1 arrays + idx = idx[..., None] + a[idx, idx] + def test_promoted_scalar_inherits_device(): device1 = Device("device1") From fc8e7314a55a6d1207bc6ae1ea5b17aa8fdb363e Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 24 Feb 2025 10:36:31 +0100 Subject: [PATCH 226/252] ENH: set the default version to 2024.12 And adapt test_flags accordingly. --- array_api_strict/_flags.py | 5 ++-- array_api_strict/tests/test_flags.py | 45 +++++++++++++++++----------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 279b0e7..3fce8a0 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -22,11 +22,12 @@ "2021.12", "2022.12", "2023.12", + "2024.12" ) -draft_version = "2024.12" +draft_version = "2025.12" -API_VERSION = default_version = "2023.12" +API_VERSION = default_version = "2024.12" BOOLEAN_INDEXING = True diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index dcfc20d..764ca77 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -19,7 +19,7 @@ def test_flag_defaults(): flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2023.12', + 'api_version': '2024.12', 'boolean_indexing': True, 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft'), @@ -36,7 +36,7 @@ def test_reset_flags(): reset_array_api_strict_flags() flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2023.12', + 'api_version': '2024.12', 'boolean_indexing': True, 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft'), @@ -47,7 +47,7 @@ def test_setting_flags(): set_array_api_strict_flags(data_dependent_shapes=False) flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2023.12', + 'api_version': '2024.12', 'boolean_indexing': True, 'data_dependent_shapes': False, 'enabled_extensions': ('linalg', 'fft'), @@ -55,7 +55,7 @@ def test_setting_flags(): set_array_api_strict_flags(enabled_extensions=('fft',)) flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2023.12', + 'api_version': '2024.12', 'boolean_indexing': True, 'data_dependent_shapes': False, 'enabled_extensions': ('fft',), @@ -98,15 +98,26 @@ def test_flags_api_version_2023_12(): } def test_flags_api_version_2024_12(): - # Make sure setting the version to 2024.12 issues a warning. + set_array_api_strict_flags(api_version='2024.12') + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2024.12', + 'boolean_indexing': True, + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } + + +def test_flags_api_version_2025_12(): + # Make sure setting the version to 2025.12 issues a warning. with pytest.warns(UserWarning) as record: - set_array_api_strict_flags(api_version='2024.12') + set_array_api_strict_flags(api_version='2025.12') assert len(record) == 1 - assert '2024.12' in str(record[0].message) + assert '2025.12' in str(record[0].message) assert 'draft' in str(record[0].message) flags = get_array_api_strict_flags() assert flags == { - 'api_version': '2024.12', + 'api_version': '2025.12', 'boolean_indexing': True, 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft'), @@ -125,9 +136,12 @@ def test_setting_flags_invalid(): def test_api_version(): # Test defaults - assert xp.__array_api_version__ == '2023.12' + assert xp.__array_api_version__ == '2024.12' # Test setting the version + set_array_api_strict_flags(api_version='2023.12') + assert xp.__array_api_version__ == '2023.12' + set_array_api_strict_flags(api_version='2022.12') assert xp.__array_api_version__ == '2022.12' @@ -315,8 +329,8 @@ def test_api_version_2023_12(func_name): def test_api_version_2024_12(func_name): func = api_version_2024_12_examples[func_name] - # By default, these functions should error - pytest.raises(RuntimeError, func) + # By default, these functions should not error + func() # In 2022.12 and 2023.12, these functions should error set_array_api_strict_flags(api_version='2022.12') @@ -324,11 +338,6 @@ def test_api_version_2024_12(func_name): set_array_api_strict_flags(api_version='2023.12') pytest.raises(RuntimeError, func) - # They should not error in 2024.12 - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2024.12') - func() - # Test the behavior gets updated properly set_array_api_strict_flags(api_version='2023.12') pytest.raises(RuntimeError, func) @@ -435,9 +444,9 @@ def test_environment_variables(): # ARRAY_API_STRICT_API_VERSION ('''\ import array_api_strict as xp -assert xp.__array_api_version__ == '2023.12' +assert xp.__array_api_version__ == '2024.12' -assert xp.get_array_api_strict_flags()['api_version'] == '2023.12' +assert xp.get_array_api_strict_flags()['api_version'] == '2024.12' ''', {}), *[ From 59a9ce726b94c8a832e9a0396f71025a40f2a371 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 24 Feb 2025 10:58:58 +0100 Subject: [PATCH 227/252] TST: update tests for 2024.12 revision (warnings, defaults) 2024.12 features do not emit warnings by default. --- array_api_strict/tests/test_array_object.py | 11 ++++---- .../tests/test_data_type_functions.py | 13 ++++----- .../tests/test_elementwise_functions.py | 10 ++----- .../tests/test_searching_functions.py | 28 ++++++++----------- .../tests/test_statistical_functions.py | 10 +++---- 5 files changed, 30 insertions(+), 42 deletions(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index ef76c28..e24a40f 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -509,10 +509,10 @@ def test_array_keys_use_private_array(): def test_array_namespace(): a = ones((3, 3)) assert a.__array_namespace__() == array_api_strict - assert array_api_strict.__array_api_version__ == "2023.12" + assert array_api_strict.__array_api_version__ == "2024.12" assert a.__array_namespace__(api_version=None) is array_api_strict - assert array_api_strict.__array_api_version__ == "2023.12" + assert array_api_strict.__array_api_version__ == "2024.12" assert a.__array_namespace__(api_version="2022.12") is array_api_strict assert array_api_strict.__array_api_version__ == "2022.12" @@ -525,11 +525,12 @@ def test_array_namespace(): assert array_api_strict.__array_api_version__ == "2021.12" with pytest.warns(UserWarning): - assert a.__array_namespace__(api_version="2024.12") is array_api_strict - assert array_api_strict.__array_api_version__ == "2024.12" + assert a.__array_namespace__(api_version="2025.12") is array_api_strict + assert array_api_strict.__array_api_version__ == "2025.12" + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) - pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2025.12")) + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2026.12")) def test_iter(): pytest.raises(TypeError, lambda: iter(asarray(3))) diff --git a/array_api_strict/tests/test_data_type_functions.py b/array_api_strict/tests/test_data_type_functions.py index 863d3d4..919c0b4 100644 --- a/array_api_strict/tests/test_data_type_functions.py +++ b/array_api_strict/tests/test_data_type_functions.py @@ -80,12 +80,11 @@ def test_result_type_py_scalars(api_version): with pytest.raises(TypeError): result_type(int16, 3) else: - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version=api_version) + set_array_api_strict_flags(api_version=api_version) - assert result_type(int8, 3) == int8 - assert result_type(uint8, 3) == uint8 - assert result_type(float64, 3) == float64 + assert result_type(int8, 3) == int8 + assert result_type(uint8, 3) == uint8 + assert result_type(float64, 3) == float64 - with pytest.raises(TypeError): - result_type(int64, True) + with pytest.raises(TypeError): + result_type(int64, True) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 93078ed..99596b4 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -3,7 +3,6 @@ from pytest import raises as assert_raises from numpy.testing import suppress_warnings -import pytest from .. import asarray, _elementwise_functions from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift @@ -134,8 +133,7 @@ def _array_vals(dtypes): yield asarray(1., dtype=d) # Use the latest version of the standard so all functions are included - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version="2024.12") + set_array_api_strict_flags(api_version="2024.12") for func_name, types in elementwise_function_input_types.items(): dtypes = _dtype_categories[types] @@ -171,8 +169,7 @@ def _array_vals(): yield asarray(1.0, dtype=d) # Use the latest version of the standard so all functions are included - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version="2024.12") + set_array_api_strict_flags(api_version="2024.12") for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): @@ -216,8 +213,7 @@ def test_scalars(): # arguments, and reject (scalar, scalar) arguments. # Use the latest version of the standard so that scalars are actually allowed - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version="2024.12") + set_array_api_strict_flags(api_version="2024.12") def _array_vals(): for d in _integer_dtypes: diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index dfb3fe7..0e54d5f 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -3,28 +3,22 @@ import array_api_strict as xp from array_api_strict import ArrayAPIStrictFlags -from array_api_strict._flags import draft_version def test_where_with_scalars(): x = xp.asarray([1, 2, 3, 1]) # Versions up to and including 2023.12 don't support scalar arguments - with pytest.raises(AttributeError, match="object has no attribute 'dtype'"): - xp.where(x == 1, 42, 44) + with ArrayAPIStrictFlags(api_version='2023.12'): + with pytest.raises(AttributeError, match="object has no attribute 'dtype'"): + xp.where(x == 1, 42, 44) # Versions after 2023.12 support scalar arguments - with (pytest.warns( - UserWarning, - match="The 2024.12 version of the array API specification is in draft status" - ), - ArrayAPIStrictFlags(api_version=draft_version), - ): - x_where = xp.where(x == 1, xp.asarray(42), 44) - - expected = xp.asarray([42, 44, 44, 42]) - assert xp.all(x_where == expected) - - # The spec does not allow both x1 and x2 to be scalars - with pytest.raises(ValueError, match="One of"): - xp.where(x == 1, 42, 44) + x_where = xp.where(x == 1, xp.asarray(42), 44) + + expected = xp.asarray([42, 44, 44, 42]) + assert xp.all(x_where == expected) + + # The spec does not allow both x1 and x2 to be scalars + with pytest.raises(ValueError, match="One of"): + xp.where(x == 1, 42, 44) diff --git a/array_api_strict/tests/test_statistical_functions.py b/array_api_strict/tests/test_statistical_functions.py index c97670d..d702b17 100644 --- a/array_api_strict/tests/test_statistical_functions.py +++ b/array_api_strict/tests/test_statistical_functions.py @@ -1,7 +1,7 @@ import cmath import pytest -from .._flags import set_array_api_strict_flags +from .._flags import set_array_api_strict_flags, ArrayAPIStrictFlags import array_api_strict as xp @@ -44,12 +44,10 @@ def test_sum_prod_trace_2023_12(func_name): def test_mean_complex(): a = xp.asarray([1j, 2j, 3j]) - set_array_api_strict_flags(api_version='2023.12') - with pytest.raises(TypeError): - xp.mean(a) + with ArrayAPIStrictFlags(api_version='2023.12'): + with pytest.raises(TypeError): + xp.mean(a) - with pytest.warns(UserWarning): - set_array_api_strict_flags(api_version='2024.12') m = xp.mean(a) assert cmath.isclose(complex(m), 2j) From a86d0bfe4b7b70ddc8f141b8ccf2c1cacde58bf2 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 2 Feb 2025 15:01:15 +0100 Subject: [PATCH 228/252] DOC: update the changelog for the 2.3 release --- docs/changelog.md | 68 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index d33dc24..7f6be2c 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,73 @@ # Changelog +## 2.3 (2025-XX-XX) + +### Major Changes + +- The default version of the array API standard is now 2024.12. Previous versions can + still be enabled via the [flags API](array-api-strict-flags). + + Note that this support is still relatively untested. Please [report any + issues](https://github.com/data-apis/array-api-strict/issues) you find. + +- Binary elementwise functions now accept python scalars: the only requirement is that + at least one of the arguments must be an array; the other argument may be either + a python scalar or an array. Python scalars are handled in accordance with the + type promotion rules, as specified by the standard. + This change unifies the behavior of binary functions and their matching operators, + (where available), such as `multiply(x1, x2)` and `__mul__(self, other)`. + + `where` accepts arrays or scalars as its 2nd and 3rd arguments, `x1` and `x2`. + The first argument, `condition`, must be an array. + + `result_type` accepts arrays and scalars and computes the result dtype according + to the promotion rules. + +- Ergonomics of working with complex values has been improved: + + - binary operators accept complex scalars and real arrays and preserve the floating point + precision: `1j*f32_array` returns a `complex64` array + - `mean` accepts complex floating-point arrays. + - `real` and `conj` accept numeric arguments, including real floating point data. + Note that `imag` still requires its input to be a complex array. + +- The following functions, new in the 2024.12 standard revision, are implemented: + + - `count_nonzero` + - `cumulative_prod` + +- `fftfreq` and `rfftfreq` functions accept a new `dtype` argument to control the + the data type of their output. + + +### Minor Changes + +- `vecdot` now conjugates the first argument, in accordance with the standard. + +- `astype` now raises a `TypeError` instead of casting a complex floating-point + array to a real-valued or an integral data type. + +- `where` requires that its first argument, `condition` has a boolean data dtype, + and raises a `TypeError` otherwise. + +- `isdtype` raises a `TypeError` is its argument is not a dtype object. + +- arrays created with `from_dlpack` now correctly set their `device` attribute. + +- the build system now uses `pyproject.toml`, not `setup.py`. + +### Contributors + +The following users contributed to this release: + +Aaron Meurer +Clément Robert +Guido Imperiale +Evgeni Burovski +Lucas Colley +Tim Head + + ## 2.2 (2024-11-11) ### Major Changes From 5d2d42f92f84ac747f0667df9b3fe6894073c6e9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 3 Mar 2025 21:57:11 +0000 Subject: [PATCH 229/252] Bump dawidd6/action-download-artifact from 8 to 9 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 8 to 9 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v8...v9) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 3700c17..fc61258 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download Artifact - uses: dawidd6/action-download-artifact@v8 + uses: dawidd6/action-download-artifact@v9 with: workflow: docs-build.yml name: docs-build From bde69fd950ff273ad24d9f2c115aafee66e57b65 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 6 Mar 2025 11:27:08 +0100 Subject: [PATCH 230/252] BUG: fix where(cond, float_array, int) This should be allowed by the spec: python int scalars can combine with float arrays. --- array_api_strict/_helpers.py | 2 +- array_api_strict/_searching_functions.py | 14 ++------------ .../tests/test_searching_functions.py | 15 ++++++++++++++- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py index 2258d29..d3fc9c9 100644 --- a/array_api_strict/_helpers.py +++ b/array_api_strict/_helpers.py @@ -31,7 +31,7 @@ def _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name): x2 = x1._promote_scalar(x2) else: if x1.dtype not in _allowed_dtypes or x2.dtype not in _allowed_dtypes: - raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. " + raise TypeError(f"Only {dtype_category} dtypes are allowed in {func_name}(...). " f"Got {x1.dtype} and {x2.dtype}.") return x1, x2 diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index ad32aaa..9864132 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -3,6 +3,7 @@ from ._array_object import Array from ._dtypes import _result_type, _real_numeric_dtypes, bool as _bool from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags +from ._helpers import _maybe_normalize_py_scalars from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -101,18 +102,7 @@ def where( See its docstring for more information. """ if get_array_api_strict_flags()['api_version'] > '2023.12': - num_scalars = 0 - - if isinstance(x1, (bool, float, complex, int)): - x1 = Array._new(np.asarray(x1), device=condition.device) - num_scalars += 1 - - if isinstance(x2, (bool, float, complex, int)): - x2 = Array._new(np.asarray(x2), device=condition.device) - num_scalars += 1 - - if num_scalars == 2: - raise ValueError("One of x1, x2 arguments must be an array.") + x1, x2 = _maybe_normalize_py_scalars(x1, x2, "all", "where") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index 0e54d5f..016862c 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -20,5 +20,18 @@ def test_where_with_scalars(): assert xp.all(x_where == expected) # The spec does not allow both x1 and x2 to be scalars - with pytest.raises(ValueError, match="One of"): + with pytest.raises(TypeError, match="Two scalars"): xp.where(x == 1, 42, 44) + + +def test_where_mixed_dtypes(): + # https://github.com/data-apis/array-api-strict/issues/131 + x = xp.asarray([1., 2.]) + res = xp.where(x > 1.5, x, 0) + assert res.dtype == x.dtype + assert all(res == xp.asarray([0., 2.])) + + # retry with boolean x1, x2 + c = x > 1.5 + res = xp.where(c, False, c) + assert all(res == xp.asarray([False, False])) From 709f4febd201829777c667ee2fe090912e65981b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 14 Mar 2025 10:11:46 +0100 Subject: [PATCH 231/252] TST: add a regression test for `where` Check that mixing scalars with arrays preserves the dtype --- array_api_strict/tests/test_searching_functions.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index 016862c..2a3a79e 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -35,3 +35,10 @@ def test_where_mixed_dtypes(): c = x > 1.5 res = xp.where(c, False, c) assert all(res == xp.asarray([False, False])) + + +def test_where_f32(): + # https://github.com/data-apis/array-api-strict/issues/131#issuecomment-2723016294 + res = xp.where(xp.asarray([True, False]), 1., xp.asarray([2, 2], dtype=xp.float32)) + assert res.dtype == xp.float32 + From ea0d0b89320d488cf8775c494195632b2f9d0842 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 20 Mar 2025 10:40:19 +0100 Subject: [PATCH 232/252] DOC: update the changelog for 2.3.1 --- docs/changelog.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/changelog.md b/docs/changelog.md index 7f6be2c..5c60162 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,6 +1,13 @@ # Changelog -## 2.3 (2025-XX-XX) +## 2.3.1 (2025-03-20) + +This is a bugfix release with no new features compared to 2.3. This release fixes an +issue with `where` for scalar arguments, found in downstream testing of the 2024.12 +support. + + +## 2.3 (2025-02-27) ### Major Changes From d7175193b7d133a817f431b410803a121fc0ce0a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 20 Mar 2025 11:20:12 +0100 Subject: [PATCH 233/252] CI: test 2024.12 on CI --- .github/workflows/array-api-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 9f168cb..0b2ce1d 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -4,7 +4,7 @@ on: [push, pull_request] env: PYTEST_ARGS: "-v -rxXfE --ci --hypothesis-disable-deadline --max-examples 200" - API_VERSIONS: "2022.12 2023.12" + API_VERSIONS: "2022.12 2023.12 2024.12" jobs: array-api-tests: From aeebf69d6e4b4dc539b67c485827e18d494ce695 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 20 Mar 2025 11:36:07 +0100 Subject: [PATCH 234/252] BLD: upper cap setuptools --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 710f23e..b3d2594 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools >= 61.0", "setuptools_scm>8"] +requires = ["setuptools >= 61.0,<=75", "setuptools_scm>8"] build-backend = "setuptools.build_meta" [project] From 1dfe06e2bfc206ecf1dfc4c91e6d88299f3e170d Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 20 Mar 2025 11:35:51 +0100 Subject: [PATCH 235/252] TST: add skips for weird special cases sync array-api-strict/array_api_tests-xfails.txt and array-api-tests/array_api_strict-skipss.txt --- array-api-tests-xfails.txt | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/array-api-tests-xfails.txt b/array-api-tests-xfails.txt index 68c7fdb..a6919dd 100644 --- a/array-api-tests-xfails.txt +++ b/array-api-tests-xfails.txt @@ -3,6 +3,26 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] +# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, all libraries do just that +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + # The test suite is incorrectly checking sums that have loss of significance # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum From 1aa4fd5de54f4a05998675164f502dc51b077473 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 2 Mar 2025 18:34:51 +0100 Subject: [PATCH 236/252] CI: update the workflows - stop testing python 3.9 - start testing python 3.13 - actually test Array API revision 2024.12 --- .github/workflows/array-api-tests.yml | 6 +++--- .github/workflows/tests.yml | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 0b2ce1d..815b82b 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -11,11 +11,11 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] + python-version: ['3.10', '3.11', '3.12', '3.13'] numpy-version: ['1.26', 'dev'] exclude: - - python-version: '3.8' - numpy-version: 'dev' + - python-version: '3.13' + numpy-version: '1.26' steps: - name: Checkout array-api-strict diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d8124d4..703e6e7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -5,11 +5,11 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] + python-version: ['3.10', '3.11', '3.12', '3.13'] numpy-version: ['1.26', 'dev'] exclude: - - python-version: '3.8' - numpy-version: 'dev' + - python-version: '3.13' + numpy-version: '1.26' fail-fast: true steps: - uses: actions/checkout@v4 From ae0378626ce9ceefdd378b7c4dcc157f57f852c8 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 2 Mar 2025 19:16:45 +0100 Subject: [PATCH 237/252] TST: update xfails for the spec 2024.12 revision --- .github/workflows/array-api-tests.yml | 4 ++-- array-api-tests-xfails.txt | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 815b82b..114f42d 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -3,8 +3,8 @@ name: Array API Tests on: [push, pull_request] env: - PYTEST_ARGS: "-v -rxXfE --ci --hypothesis-disable-deadline --max-examples 200" - API_VERSIONS: "2022.12 2023.12 2024.12" + PYTEST_ARGS: "-v -rxXfE --hypothesis-disable-deadline --max-examples 200" + API_VERSIONS: "2023.12 2024.12" jobs: array-api-tests: diff --git a/array-api-tests-xfails.txt b/array-api-tests-xfails.txt index a6919dd..f68095b 100644 --- a/array-api-tests-xfails.txt +++ b/array-api-tests-xfails.txt @@ -26,3 +26,6 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and # The test suite is incorrectly checking sums that have loss of significance # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_sum + +array_api_tests/test_special_cases.py::test_nan_propagation[cumulative_prod] + From 34ce3f64d95eee40fa742076154bd30cc9537c1d Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 31 Mar 2025 23:32:21 +0100 Subject: [PATCH 238/252] TYP: type annotations (#135) * TYP: type annotations * Python 3.9 fixes * self-review * code review * Update array_api_strict/_array_object.py Co-authored-by: Joren Hammudoglu * Apply suggestions from code review Co-authored-by: Joren Hammudoglu * fix * code review * fixes * normalize order * Fancy indexing in `__getitem__` signature * verbose Python scalar types --------- Co-authored-by: Joren Hammudoglu --- array_api_strict/__init__.py | 16 +- array_api_strict/_array_object.py | 295 ++++++++++-------- array_api_strict/_constants.py | 2 +- array_api_strict/_creation_functions.py | 205 ++++++------ array_api_strict/_data_type_functions.py | 95 +++--- array_api_strict/_dtypes.py | 59 ++-- array_api_strict/_elementwise_functions.py | 47 +-- array_api_strict/_fft.py | 58 ++-- array_api_strict/_flags.py | 97 ++++-- array_api_strict/_helpers.py | 19 +- array_api_strict/_indexing_functions.py | 11 +- array_api_strict/_info.py | 131 ++++---- array_api_strict/_linalg.py | 108 ++++--- array_api_strict/_linear_algebra_functions.py | 21 +- array_api_strict/_manipulation_functions.py | 54 ++-- array_api_strict/_searching_functions.py | 41 +-- array_api_strict/_set_functions.py | 8 +- array_api_strict/_sorting_functions.py | 12 +- array_api_strict/_statistical_functions.py | 121 ++++--- array_api_strict/_typing.py | 80 ++--- array_api_strict/_utility_functions.py | 24 +- array_api_strict/py.typed | 0 array_api_strict/tests/test_validation.py | 4 +- pyproject.toml | 15 + 24 files changed, 807 insertions(+), 716 deletions(-) create mode 100644 array_api_strict/py.typed diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index e6a1763..116df25 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -16,6 +16,8 @@ """ +from types import ModuleType + __all__ = [] # Warning: __array_api_version__ could change globally with @@ -325,12 +327,16 @@ ArrayAPIStrictFlags, ) -__all__ += ['set_array_api_strict_flags', 'get_array_api_strict_flags', 'reset_array_api_strict_flags', 'ArrayAPIStrictFlags'] +__all__ += [ + 'set_array_api_strict_flags', + 'get_array_api_strict_flags', + 'reset_array_api_strict_flags', + 'ArrayAPIStrictFlags', + '__version__', +] try: - from . import _version - __version__ = _version.__version__ - del _version + from ._version import __version__ # type: ignore[import-not-found,unused-ignore] except ImportError: __version__ = "unknown" @@ -340,7 +346,7 @@ # use __getattr__. Note that linalg and fft are dynamically added and removed # from __all__ in set_array_api_strict_flags. -def __getattr__(name): +def __getattr__(name: str) -> ModuleType: if name in ['linalg', 'fft']: if name in get_array_api_strict_flags()['enabled_extensions']: if name == 'linalg': diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 0595594..1304d5a 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -16,62 +16,70 @@ from __future__ import annotations import operator +import sys +from collections.abc import Iterator from enum import IntEnum +from types import ModuleType +from typing import TYPE_CHECKING, Any, Final, Literal, SupportsIndex -from ._creation_functions import asarray +import numpy as np +import numpy.typing as npt + +from ._creation_functions import Undef, _undef, asarray from ._dtypes import ( - _DType, + DType, _all_dtypes, _boolean_dtypes, + _complex_floating_dtypes, + _dtype_categories, + _floating_dtypes, _integer_dtypes, _integer_or_boolean_dtypes, - _floating_dtypes, - _real_floating_dtypes, - _complex_floating_dtypes, _numeric_dtypes, - _result_type, - _dtype_categories, + _real_floating_dtypes, _real_to_complex_map, + _result_type, ) from ._flags import get_array_api_strict_flags, set_array_api_strict_flags +from ._typing import PyCapsule -from typing import TYPE_CHECKING, SupportsIndex -import types +if sys.version_info >= (3, 10): + from types import EllipsisType +elif TYPE_CHECKING: + from typing_extensions import EllipsisType +else: + EllipsisType = type(Ellipsis) -if TYPE_CHECKING: - from typing import Optional, Tuple, Union, Any - from ._typing import PyCapsule, Dtype - import numpy.typing as npt - -import numpy as np class Device: - def __init__(self, device="CPU_DEVICE"): + _device: Final[str] + __slots__ = ("_device", "__weakref__") + + def __init__(self, device: str = "CPU_DEVICE"): if device not in ("CPU_DEVICE", "device1", "device2"): raise ValueError(f"The device '{device}' is not a valid choice.") self._device = device - def __repr__(self): + def __repr__(self) -> str: return f"array_api_strict.Device('{self._device}')" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, Device): return False return self._device == other._device - def __hash__(self): + def __hash__(self) -> int: return hash(("Device", self._device)) CPU_DEVICE = Device() ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2")) -_default = object() - # See https://github.com/data-apis/array-api-strict/issues/67 and the comment # on __array__ below. _allow_array = True + class Array: """ n-d array object for the array API namespace. @@ -87,12 +95,16 @@ class Array: functions, such as asarray(). """ + _array: npt.NDArray[Any] + _dtype: DType + _device: Device + __slots__ = ("_array", "_dtype", "_device", "__weakref__") # Use a custom constructor instead of __init__, as manually initializing # this class is not supported API. @classmethod - def _new(cls, x, /, device): + def _new(cls, x: npt.NDArray[Any] | np.generic, /, device: Device | None) -> Array: """ This is a private method for initializing the array API Array object. @@ -107,7 +119,7 @@ def _new(cls, x, /, device): if isinstance(x, np.generic): # Convert the array scalar to a 0-D array x = np.asarray(x) - _dtype = _DType(x.dtype) + _dtype = DType(x.dtype) if _dtype not in _all_dtypes: raise TypeError( f"The array_api_strict namespace does not support the dtype '{x.dtype}'" @@ -120,7 +132,7 @@ def _new(cls, x, /, device): return obj # Prevent Array() from working - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: object, **kwargs: object) -> Array: raise TypeError( "The array_api_strict Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead." ) @@ -128,7 +140,7 @@ def __new__(cls, *args, **kwargs): # These functions are not required by the spec, but are implemented for # the sake of usability. - def __repr__(self: Array, /) -> str: + def __repr__(self) -> str: """ Performs the operation __repr__. """ @@ -159,7 +171,9 @@ def __repr__(self: Array, /) -> str: # This was implemented historically for compatibility, and removing it has # caused issues for some libraries (see # https://github.com/data-apis/array-api-strict/issues/67). - def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]: + def __array__( + self, dtype: None | np.dtype[Any] = None, copy: None | bool = None + ) -> npt.NDArray[Any]: # We have to allow this to be internally enabled as there's no other # easy way to parse a list of Array objects in asarray(). if _allow_array: @@ -184,7 +198,9 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None # spec in places where it either deviates from or is more strict than # NumPy behavior - def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array: + def _check_allowed_dtypes( + self, other: Array | bool | int | float | complex, dtype_category: str, op: str + ) -> Array: """ Helper function for operators to only allow specific input dtypes @@ -197,7 +213,7 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor if self.dtype not in _dtype_categories[dtype_category]: raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}") - if isinstance(other, (int, complex, float, bool)): + if isinstance(other, (bool, int, float, complex)): other = self._promote_scalar(other) elif isinstance(other, Array): if other.dtype not in _dtype_categories[dtype_category]: @@ -225,16 +241,18 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor return other - def _check_device(self, other): + def _check_device(self, other: Array | bool | int | float | complex) -> None: """Check that other is on a device compatible with the current array""" - if isinstance(other, (int, complex, float, bool)): + if isinstance(other, (bool, int, float, complex)): return elif isinstance(other, Array): if self.device != other.device: raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") + else: + raise TypeError(f"Expected Array | python scalar; got {type(other)}") # Helper function to match the type promotion rules in the spec - def _promote_scalar(self, scalar): + def _promote_scalar(self, scalar: bool | int | float | complex) -> Array: """ Returns a promoted version of a Python scalar appropriate for use with operations on self. @@ -291,7 +309,7 @@ def _promote_scalar(self, scalar): return Array._new(np.array(scalar, dtype=target_dtype._np_dtype), device=self.device) @staticmethod - def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: + def _normalize_two_args(x1: Array, x2: Array) -> tuple[Array, Array]: """ Normalize inputs to two arg functions to fix type promotion rules @@ -327,7 +345,17 @@ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: # Note: A large fraction of allowed indices are disallowed here (see the # docstring below) - def _validate_index(self, key, op="getitem"): + def _validate_index( + self, + key: ( + int + | slice + | EllipsisType + | Array + | tuple[int | slice | EllipsisType | Array | None, ...] + ), + op: Literal["getitem", "setitem"] = "getitem", + ) -> None: """ Validate an index according to the array API. @@ -509,7 +537,7 @@ def _validate_index(self, key, op="getitem"): # Everything below this line is required by the spec. - def __abs__(self: Array, /) -> Array: + def __abs__(self) -> Array: """ Performs the operation __abs__. """ @@ -518,7 +546,7 @@ def __abs__(self: Array, /) -> Array: res = self._array.__abs__() return self.__class__._new(res, device=self.device) - def __add__(self: Array, other: Union[int, float, Array], /) -> Array: + def __add__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __add__. """ @@ -530,7 +558,7 @@ def __add__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__add__(other._array) return self.__class__._new(res, device=self.device) - def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __and__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __and__. """ @@ -542,9 +570,7 @@ def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: res = self._array.__and__(other._array) return self.__class__._new(res, device=self.device) - def __array_namespace__( - self: Array, /, *, api_version: Optional[str] = None - ) -> types.ModuleType: + def __array_namespace__(self, /, *, api_version: str | None = None) -> ModuleType: """ Return the array_api_strict namespace corresponding to api_version. @@ -563,7 +589,7 @@ def __array_namespace__( import array_api_strict return array_api_strict - def __bool__(self: Array, /) -> bool: + def __bool__(self) -> bool: """ Performs the operation __bool__. """ @@ -573,7 +599,7 @@ def __bool__(self: Array, /) -> bool: res = self._array.__bool__() return res - def __complex__(self: Array, /) -> complex: + def __complex__(self) -> complex: """ Performs the operation __complex__. """ @@ -584,52 +610,52 @@ def __complex__(self: Array, /) -> complex: return res def __dlpack__( - self: Array, + self, /, *, - stream: Optional[Union[int, Any]] = None, - max_version: Optional[tuple[int, int]] = _default, - dl_device: Optional[tuple[IntEnum, int]] = _default, - copy: Optional[bool] = _default, + stream: Any = None, + max_version: tuple[int, int] | None | Undef = _undef, + dl_device: tuple[IntEnum, int] | None | Undef = _undef, + copy: bool | None | Undef = _undef, ) -> PyCapsule: """ Performs the operation __dlpack__. """ if get_array_api_strict_flags()['api_version'] < '2023.12': - if max_version is not _default: + if max_version is not _undef: raise ValueError("The max_version argument to __dlpack__ requires at least version 2023.12 of the array API") - if dl_device is not _default: + if dl_device is not _undef: raise ValueError("The device argument to __dlpack__ requires at least version 2023.12 of the array API") - if copy is not _default: + if copy is not _undef: raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API") if np.lib.NumpyVersion(np.__version__) < '2.1.0': - if max_version not in [_default, None]: + if max_version not in [_undef, None]: raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented") - if dl_device not in [_default, None]: + if dl_device not in [_undef, None]: raise NotImplementedError("The device argument to __dlpack__ is not yet implemented") - if copy not in [_default, None]: + if copy not in [_undef, None]: raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented") return self._array.__dlpack__(stream=stream) else: kwargs = {'stream': stream} - if max_version is not _default: + if max_version is not _undef: kwargs['max_version'] = max_version - if dl_device is not _default: + if dl_device is not _undef: kwargs['dl_device'] = dl_device - if copy is not _default: + if copy is not _undef: kwargs['copy'] = copy return self._array.__dlpack__(**kwargs) - def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]: + def __dlpack_device__(self) -> tuple[IntEnum, int]: """ Performs the operation __dlpack_device__. """ # Note: device support is required for this return self._array.__dlpack_device__() - def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: + def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override] """ Performs the operation __eq__. """ @@ -643,7 +669,7 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: res = self._array.__eq__(other._array) return self.__class__._new(res, device=self.device) - def __float__(self: Array, /) -> float: + def __float__(self) -> float: """ Performs the operation __float__. """ @@ -655,7 +681,7 @@ def __float__(self: Array, /) -> float: res = self._array.__float__() return res - def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array: + def __floordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __floordiv__. """ @@ -667,7 +693,7 @@ def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__floordiv__(other._array) return self.__class__._new(res, device=self.device) - def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: + def __ge__(self, other: Array | int | float, /) -> Array: """ Performs the operation __ge__. """ @@ -680,14 +706,15 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: return self.__class__._new(res, device=self.device) def __getitem__( - self: Array, - key: Union[ - int, - slice, - ellipsis, # noqa: F821 - Tuple[Union[int, slice, ellipsis, None], ...], # noqa: F821 - Array, - ], + self, + key: ( + int + | slice + | EllipsisType + | Array + | None + | tuple[int | slice | EllipsisType | Array | None, ...] + ), /, ) -> Array: """ @@ -696,14 +723,13 @@ def __getitem__( # XXX Does key have to be on the same device? Is there an exception for CPU_DEVICE? # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index - self._validate_index(key) - if isinstance(key, Array): - # Indexing self._array with array_api_strict arrays can be erroneous - key = key._array - res = self._array.__getitem__(key) + self._validate_index(key, op="getitem") + # Indexing self._array with array_api_strict arrays can be erroneous + np_key = key._array if isinstance(key, Array) else key + res = self._array.__getitem__(np_key) return self._new(res, device=self.device) - def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: + def __gt__(self, other: Array | int | float, /) -> Array: """ Performs the operation __gt__. """ @@ -715,7 +741,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__gt__(other._array) return self.__class__._new(res, device=other.device) - def __int__(self: Array, /) -> int: + def __int__(self) -> int: """ Performs the operation __int__. """ @@ -727,14 +753,14 @@ def __int__(self: Array, /) -> int: res = self._array.__int__() return res - def __index__(self: Array, /) -> int: + def __index__(self) -> int: """ Performs the operation __index__. """ res = self._array.__index__() return res - def __invert__(self: Array, /) -> Array: + def __invert__(self) -> Array: """ Performs the operation __invert__. """ @@ -743,7 +769,7 @@ def __invert__(self: Array, /) -> Array: res = self._array.__invert__() return self.__class__._new(res, device=self.device) - def __iter__(self: Array, /): + def __iter__(self) -> Iterator[Array]: """ Performs the operation __iter__. """ @@ -758,7 +784,7 @@ def __iter__(self: Array, /): # implemented, which implies iteration on 1-D arrays. return (Array._new(i, device=self.device) for i in self._array) - def __le__(self: Array, other: Union[int, float, Array], /) -> Array: + def __le__(self, other: Array | int | float, /) -> Array: """ Performs the operation __le__. """ @@ -770,7 +796,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__le__(other._array) return self.__class__._new(res, device=self.device) - def __lshift__(self: Array, other: Union[int, Array], /) -> Array: + def __lshift__(self, other: Array | int, /) -> Array: """ Performs the operation __lshift__. """ @@ -782,7 +808,7 @@ def __lshift__(self: Array, other: Union[int, Array], /) -> Array: res = self._array.__lshift__(other._array) return self.__class__._new(res, device=self.device) - def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: + def __lt__(self, other: Array | int | float, /) -> Array: """ Performs the operation __lt__. """ @@ -794,7 +820,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__lt__(other._array) return self.__class__._new(res, device=self.device) - def __matmul__(self: Array, other: Array, /) -> Array: + def __matmul__(self, other: Array, /) -> Array: """ Performs the operation __matmul__. """ @@ -807,7 +833,7 @@ def __matmul__(self: Array, other: Array, /) -> Array: res = self._array.__matmul__(other._array) return self.__class__._new(res, device=self.device) - def __mod__(self: Array, other: Union[int, float, Array], /) -> Array: + def __mod__(self, other: Array | int | float, /) -> Array: """ Performs the operation __mod__. """ @@ -819,7 +845,7 @@ def __mod__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__mod__(other._array) return self.__class__._new(res, device=self.device) - def __mul__(self: Array, other: Union[int, float, Array], /) -> Array: + def __mul__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __mul__. """ @@ -831,7 +857,7 @@ def __mul__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__mul__(other._array) return self.__class__._new(res, device=self.device) - def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array: + def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override] """ Performs the operation __ne__. """ @@ -843,7 +869,7 @@ def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array: res = self._array.__ne__(other._array) return self.__class__._new(res, device=self.device) - def __neg__(self: Array, /) -> Array: + def __neg__(self) -> Array: """ Performs the operation __neg__. """ @@ -852,7 +878,7 @@ def __neg__(self: Array, /) -> Array: res = self._array.__neg__() return self.__class__._new(res, device=self.device) - def __or__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __or__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __or__. """ @@ -864,7 +890,7 @@ def __or__(self: Array, other: Union[int, bool, Array], /) -> Array: res = self._array.__or__(other._array) return self.__class__._new(res, device=self.device) - def __pos__(self: Array, /) -> Array: + def __pos__(self) -> Array: """ Performs the operation __pos__. """ @@ -873,11 +899,11 @@ def __pos__(self: Array, /) -> Array: res = self._array.__pos__() return self.__class__._new(res, device=self.device) - def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: + def __pow__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __pow__. """ - from ._elementwise_functions import pow + from ._elementwise_functions import pow # type: ignore[attr-defined] self._check_device(other) other = self._check_allowed_dtypes(other, "numeric", "__pow__") @@ -887,7 +913,7 @@ def __pow__(self: Array, other: Union[int, float, Array], /) -> Array: # arrays, so we use pow() here instead. return pow(self, other) - def __rshift__(self: Array, other: Union[int, Array], /) -> Array: + def __rshift__(self, other: Array | int, /) -> Array: """ Performs the operation __rshift__. """ @@ -901,10 +927,16 @@ def __rshift__(self: Array, other: Union[int, Array], /) -> Array: def __setitem__( self, - key: Union[ - int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array # noqa: F821 - ], - value: Union[int, float, bool, Array], + # Almost same as __getitem__ key but doesn't accept None + # or integer arrays + key: ( + int + | slice + | EllipsisType + | Array + | tuple[int | slice | EllipsisType, ...] + ), + value: Array | bool | int | float | complex, /, ) -> None: """ @@ -913,12 +945,11 @@ def __setitem__( # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index self._validate_index(key, op="setitem") - if isinstance(key, Array): - # Indexing self._array with array_api_strict arrays can be erroneous - key = key._array - self._array.__setitem__(key, asarray(value)._array) + # Indexing self._array with array_api_strict arrays can be erroneous + np_key = key._array if isinstance(key, Array) else key + self._array.__setitem__(np_key, asarray(value)._array) - def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: + def __sub__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __sub__. """ @@ -932,7 +963,7 @@ def __sub__(self: Array, other: Union[int, float, Array], /) -> Array: # PEP 484 requires int to be a subtype of float, but __truediv__ should # not accept int. - def __truediv__(self: Array, other: Union[float, Array], /) -> Array: + def __truediv__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __truediv__. """ @@ -944,7 +975,7 @@ def __truediv__(self: Array, other: Union[float, Array], /) -> Array: res = self._array.__truediv__(other._array) return self.__class__._new(res, device=self.device) - def __xor__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __xor__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __xor__. """ @@ -956,7 +987,7 @@ def __xor__(self: Array, other: Union[int, bool, Array], /) -> Array: res = self._array.__xor__(other._array) return self.__class__._new(res, device=self.device) - def __iadd__(self: Array, other: Union[int, float, Array], /) -> Array: + def __iadd__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __iadd__. """ @@ -967,7 +998,7 @@ def __iadd__(self: Array, other: Union[int, float, Array], /) -> Array: self._array.__iadd__(other._array) return self - def __radd__(self: Array, other: Union[int, float, Array], /) -> Array: + def __radd__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __radd__. """ @@ -979,7 +1010,7 @@ def __radd__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__radd__(other._array) return self.__class__._new(res, device=self.device) - def __iand__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __iand__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __iand__. """ @@ -990,7 +1021,7 @@ def __iand__(self: Array, other: Union[int, bool, Array], /) -> Array: self._array.__iand__(other._array) return self - def __rand__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __rand__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __rand__. """ @@ -1002,7 +1033,7 @@ def __rand__(self: Array, other: Union[int, bool, Array], /) -> Array: res = self._array.__rand__(other._array) return self.__class__._new(res, device=self.device) - def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: + def __ifloordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __ifloordiv__. """ @@ -1013,7 +1044,7 @@ def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: self._array.__ifloordiv__(other._array) return self - def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: + def __rfloordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __rfloordiv__. """ @@ -1025,7 +1056,7 @@ def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__rfloordiv__(other._array) return self.__class__._new(res, device=self.device) - def __ilshift__(self: Array, other: Union[int, Array], /) -> Array: + def __ilshift__(self, other: Array | int, /) -> Array: """ Performs the operation __ilshift__. """ @@ -1036,7 +1067,7 @@ def __ilshift__(self: Array, other: Union[int, Array], /) -> Array: self._array.__ilshift__(other._array) return self - def __rlshift__(self: Array, other: Union[int, Array], /) -> Array: + def __rlshift__(self, other: Array | int, /) -> Array: """ Performs the operation __rlshift__. """ @@ -1048,7 +1079,7 @@ def __rlshift__(self: Array, other: Union[int, Array], /) -> Array: res = self._array.__rlshift__(other._array) return self.__class__._new(res, device=self.device) - def __imatmul__(self: Array, other: Array, /) -> Array: + def __imatmul__(self, other: Array, /) -> Array: """ Performs the operation __imatmul__. """ @@ -1061,7 +1092,7 @@ def __imatmul__(self: Array, other: Array, /) -> Array: res = self._array.__imatmul__(other._array) return self.__class__._new(res, device=self.device) - def __rmatmul__(self: Array, other: Array, /) -> Array: + def __rmatmul__(self, other: Array, /) -> Array: """ Performs the operation __rmatmul__. """ @@ -1074,7 +1105,7 @@ def __rmatmul__(self: Array, other: Array, /) -> Array: res = self._array.__rmatmul__(other._array) return self.__class__._new(res, device=self.device) - def __imod__(self: Array, other: Union[int, float, Array], /) -> Array: + def __imod__(self, other: Array | int | float, /) -> Array: """ Performs the operation __imod__. """ @@ -1084,7 +1115,7 @@ def __imod__(self: Array, other: Union[int, float, Array], /) -> Array: self._array.__imod__(other._array) return self - def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array: + def __rmod__(self, other: Array | int | float, /) -> Array: """ Performs the operation __rmod__. """ @@ -1096,7 +1127,7 @@ def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__rmod__(other._array) return self.__class__._new(res, device=self.device) - def __imul__(self: Array, other: Union[int, float, Array], /) -> Array: + def __imul__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __imul__. """ @@ -1106,7 +1137,7 @@ def __imul__(self: Array, other: Union[int, float, Array], /) -> Array: self._array.__imul__(other._array) return self - def __rmul__(self: Array, other: Union[int, float, Array], /) -> Array: + def __rmul__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __rmul__. """ @@ -1118,7 +1149,7 @@ def __rmul__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__rmul__(other._array) return self.__class__._new(res, device=self.device) - def __ior__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __ior__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __ior__. """ @@ -1128,7 +1159,7 @@ def __ior__(self: Array, other: Union[int, bool, Array], /) -> Array: self._array.__ior__(other._array) return self - def __ror__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __ror__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __ror__. """ @@ -1140,7 +1171,7 @@ def __ror__(self: Array, other: Union[int, bool, Array], /) -> Array: res = self._array.__ror__(other._array) return self.__class__._new(res, device=self.device) - def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array: + def __ipow__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __ipow__. """ @@ -1150,11 +1181,11 @@ def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array: self._array.__ipow__(other._array) return self - def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array: + def __rpow__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __rpow__. """ - from ._elementwise_functions import pow + from ._elementwise_functions import pow # type: ignore[attr-defined] other = self._check_allowed_dtypes(other, "numeric", "__rpow__") if other is NotImplemented: @@ -1163,7 +1194,7 @@ def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array: # for 0-d arrays, so we use pow() here instead. return pow(other, self) - def __irshift__(self: Array, other: Union[int, Array], /) -> Array: + def __irshift__(self, other: Array | int, /) -> Array: """ Performs the operation __irshift__. """ @@ -1173,7 +1204,7 @@ def __irshift__(self: Array, other: Union[int, Array], /) -> Array: self._array.__irshift__(other._array) return self - def __rrshift__(self: Array, other: Union[int, Array], /) -> Array: + def __rrshift__(self, other: Array | int, /) -> Array: """ Performs the operation __rrshift__. """ @@ -1185,7 +1216,7 @@ def __rrshift__(self: Array, other: Union[int, Array], /) -> Array: res = self._array.__rrshift__(other._array) return self.__class__._new(res, device=self.device) - def __isub__(self: Array, other: Union[int, float, Array], /) -> Array: + def __isub__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __isub__. """ @@ -1195,7 +1226,7 @@ def __isub__(self: Array, other: Union[int, float, Array], /) -> Array: self._array.__isub__(other._array) return self - def __rsub__(self: Array, other: Union[int, float, Array], /) -> Array: + def __rsub__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __rsub__. """ @@ -1207,7 +1238,7 @@ def __rsub__(self: Array, other: Union[int, float, Array], /) -> Array: res = self._array.__rsub__(other._array) return self.__class__._new(res, device=self.device) - def __itruediv__(self: Array, other: Union[float, Array], /) -> Array: + def __itruediv__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __itruediv__. """ @@ -1217,7 +1248,7 @@ def __itruediv__(self: Array, other: Union[float, Array], /) -> Array: self._array.__itruediv__(other._array) return self - def __rtruediv__(self: Array, other: Union[float, Array], /) -> Array: + def __rtruediv__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __rtruediv__. """ @@ -1229,7 +1260,7 @@ def __rtruediv__(self: Array, other: Union[float, Array], /) -> Array: res = self._array.__rtruediv__(other._array) return self.__class__._new(res, device=self.device) - def __ixor__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __ixor__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __ixor__. """ @@ -1239,7 +1270,7 @@ def __ixor__(self: Array, other: Union[int, bool, Array], /) -> Array: self._array.__ixor__(other._array) return self - def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array: + def __rxor__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __rxor__. """ @@ -1251,7 +1282,7 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array: res = self._array.__rxor__(other._array) return self.__class__._new(res, device=self.device) - def to_device(self: Array, device: Device, /, stream: None = None) -> Array: + def to_device(self, device: Device, /, stream: None = None) -> Array: if stream is not None: raise ValueError("The stream argument to to_device() is not supported") if device == self._device: @@ -1262,7 +1293,7 @@ def to_device(self: Array, device: Device, /, stream: None = None) -> Array: raise ValueError(f"Unsupported device {device!r}") @property - def dtype(self) -> Dtype: + def dtype(self) -> DType: """ Array API compatible wrapper for :py:meth:`np.ndarray.dtype `. @@ -1290,7 +1321,7 @@ def ndim(self) -> int: return self._array.ndim @property - def shape(self) -> Tuple[int, ...]: + def shape(self) -> tuple[int, ...]: """ Array API compatible wrapper for :py:meth:`np.ndarray.shape `. diff --git a/array_api_strict/_constants.py b/array_api_strict/_constants.py index 15ab81d..d78354b 100644 --- a/array_api_strict/_constants.py +++ b/array_api_strict/_constants.py @@ -4,4 +4,4 @@ inf = np.inf nan = np.nan pi = np.pi -newaxis = np.newaxis +newaxis: None = np.newaxis diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 460dba9..3b80b8a 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -1,23 +1,33 @@ from __future__ import annotations +from collections.abc import Generator from contextlib import contextmanager -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from enum import Enum +from typing import TYPE_CHECKING, Literal -if TYPE_CHECKING: - from ._typing import ( - Array, - Device, - Dtype, - NestedSequence, - SupportsBufferProtocol, - ) -from ._dtypes import _DType, _all_dtypes +import numpy as np + +from ._dtypes import DType, _all_dtypes, _np_dtype from ._flags import get_array_api_strict_flags +from ._typing import NestedSequence, SupportsBufferProtocol, SupportsDLPack + +if TYPE_CHECKING: + # TODO import from typing (requires Python >=3.13) + from typing_extensions import TypeIs + + # Circular import + from ._array_object import Array, Device + + +class Undef(Enum): + UNDEF = 0 + + +_undef = Undef.UNDEF -import numpy as np @contextmanager -def allow_array(): +def allow_array() -> Generator[None]: """ Temporarily enable Array.__array__. This is needed for np.array to parse list of lists of Array objects. @@ -30,22 +40,25 @@ def allow_array(): finally: _array_object._allow_array = original_value -def _check_valid_dtype(dtype): + +def _check_valid_dtype(dtype: DType | None) -> None: # Note: Only spelling dtypes as the dtype objects is supported. if dtype not in (None,) + _all_dtypes: raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}") -def _supports_buffer_protocol(obj): + +def _supports_buffer_protocol(obj: object) -> TypeIs[SupportsBufferProtocol]: try: - memoryview(obj) + memoryview(obj) # type: ignore[arg-type] except TypeError: return False return True -def _check_device(device): + +def _check_device(device: Device | None) -> None: # _array_object imports in this file are inside the functions to avoid # circular imports - from ._array_object import Device, ALL_DEVICES + from ._array_object import ALL_DEVICES, Device if device is not None and not isinstance(device, Device): raise ValueError(f"Unsupported device {device!r}") @@ -53,20 +66,20 @@ def _check_device(device): if device is not None and device not in ALL_DEVICES: raise ValueError(f"Unsupported device {device!r}") + def asarray( - obj: Union[ - Array, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: Array + | bool + | int + | float + | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol, /, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - copy: Optional[bool] = None, + dtype: DType | None = None, + device: Device | None = None, + copy: bool | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.asarray `. @@ -118,13 +131,13 @@ def asarray( def arange( - start: Union[int, float], + start: int | float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: int | float | None = None, + step: int | float = 1, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.arange `. @@ -136,16 +149,17 @@ def arange( _check_valid_dtype(dtype) _check_device(device) - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype), device=device) + return Array._new( + np.arange(start, stop, step, dtype=_np_dtype(dtype)), + device=device, + ) def empty( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.empty `. @@ -157,13 +171,11 @@ def empty( _check_valid_dtype(dtype) _check_device(device) - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.empty(shape, dtype=dtype), device=device) + return Array._new(np.empty(shape, dtype=_np_dtype(dtype)), device=device) def empty_like( - x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None + x: Array, /, *, dtype: DType | None = None, device: Device | None = None ) -> Array: """ Array API compatible wrapper for :py:func:`np.empty_like `. @@ -177,19 +189,17 @@ def empty_like( if device is None: device = x.device - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.empty_like(x._array, dtype=dtype), device=device) + return Array._new(np.empty_like(x._array, dtype=_np_dtype(dtype)), device=device) def eye( n_rows: int, - n_cols: Optional[int] = None, + n_cols: int | None = None, /, *, k: int = 0, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.eye `. @@ -201,45 +211,43 @@ def eye( _check_valid_dtype(dtype) _check_device(device) - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype), device=device) - + return Array._new( + np.eye(n_rows, M=n_cols, k=k, dtype=_np_dtype(dtype)), device=device + ) -_default = object() def from_dlpack( - x: object, + x: SupportsDLPack, /, *, - device: Optional[Device] = _default, - copy: Optional[bool] = _default, + device: Device | Undef | None = _undef, + copy: bool | Undef | None = _undef, ) -> Array: from ._array_object import Array if get_array_api_strict_flags()['api_version'] < '2023.12': - if device is not _default: + if device is not _undef: raise ValueError("The device argument to from_dlpack requires at least version 2023.12 of the array API") - if copy is not _default: + if copy is not _undef: raise ValueError("The copy argument to from_dlpack requires at least version 2023.12 of the array API") # Going to wait for upstream numpy support - if device is not _default: + if device is not _undef: _check_device(device) else: device = None - if copy not in [_default, None]: + if copy not in [_undef, None]: raise NotImplementedError("The copy argument to from_dlpack is not yet implemented") return Array._new(np.from_dlpack(x), device=device) def full( - shape: Union[int, Tuple[int, ...]], - fill_value: Union[int, float], + shape: int | tuple[int, ...], + fill_value: bool | int | float | complex, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.full `. @@ -253,10 +261,8 @@ def full( if isinstance(fill_value, Array) and fill_value.ndim == 0: fill_value = fill_value._array - if dtype is not None: - dtype = dtype._np_dtype - res = np.full(shape, fill_value, dtype=dtype) - if _DType(res.dtype) not in _all_dtypes: + res = np.full(shape, fill_value, dtype=_np_dtype(dtype)) + if DType(res.dtype) not in _all_dtypes: # This will happen if the fill value is not something that NumPy # coerces to one of the acceptable dtypes. raise TypeError("Invalid input to full") @@ -266,10 +272,10 @@ def full( def full_like( x: Array, /, - fill_value: Union[int, float], + fill_value: bool | int | float | complex, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.full_like `. @@ -283,10 +289,8 @@ def full_like( if device is None: device = x.device - if dtype is not None: - dtype = dtype._np_dtype - res = np.full_like(x._array, fill_value, dtype=dtype) - if _DType(res.dtype) not in _all_dtypes: + res = np.full_like(x._array, fill_value, dtype=_np_dtype(dtype)) + if DType(res.dtype) not in _all_dtypes: # This will happen if the fill value is not something that NumPy # coerces to one of the acceptable dtypes. raise TypeError("Invalid input to full_like") @@ -294,13 +298,13 @@ def full_like( def linspace( - start: Union[int, float], - stop: Union[int, float], + start: int | float | complex, + stop: int | float | complex, /, num: int, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, endpoint: bool = True, ) -> Array: """ @@ -313,12 +317,13 @@ def linspace( _check_valid_dtype(dtype) _check_device(device) - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint), device=device) + return Array._new( + np.linspace(start, stop, num, dtype=_np_dtype(dtype), endpoint=endpoint), + device=device, + ) -def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: +def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> list[Array]: """ Array API compatible wrapper for :py:func:`np.meshgrid `. @@ -348,10 +353,10 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: def ones( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.ones `. @@ -363,13 +368,11 @@ def ones( _check_valid_dtype(dtype) _check_device(device) - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.ones(shape, dtype=dtype), device=device) + return Array._new(np.ones(shape, dtype=_np_dtype(dtype)), device=device) def ones_like( - x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None + x: Array, /, *, dtype: DType | None = None, device: Device | None = None ) -> Array: """ Array API compatible wrapper for :py:func:`np.ones_like `. @@ -383,9 +386,7 @@ def ones_like( if device is None: device = x.device - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.ones_like(x._array, dtype=dtype), device=device) + return Array._new(np.ones_like(x._array, dtype=_np_dtype(dtype)), device=device) def tril(x: Array, /, *, k: int = 0) -> Array: @@ -417,10 +418,10 @@ def triu(x: Array, /, *, k: int = 0) -> Array: def zeros( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.zeros `. @@ -432,13 +433,11 @@ def zeros( _check_valid_dtype(dtype) _check_device(device) - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.zeros(shape, dtype=dtype), device=device) + return Array._new(np.zeros(shape, dtype=_np_dtype(dtype)), device=device) def zeros_like( - x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None + x: Array, /, *, dtype: DType | None = None, device: Device | None = None ) -> Array: """ Array API compatible wrapper for :py:func:`np.zeros_like `. @@ -452,6 +451,4 @@ def zeros_like( if device is None: device = x.device - if dtype is not None: - dtype = dtype._np_dtype - return Array._new(np.zeros_like(x._array, dtype=dtype), device=device) + return Array._new(np.zeros_like(x._array, dtype=_np_dtype(dtype)), device=device) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 1643043..7dc918d 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -1,38 +1,37 @@ from __future__ import annotations -from ._array_object import Array -from ._creation_functions import _check_device +from dataclasses import dataclass + +import numpy as np + +from ._array_object import Array, Device +from ._creation_functions import Undef, _check_device, _undef from ._dtypes import ( - _DType, + DType, _all_dtypes, _boolean_dtypes, - _signed_integer_dtypes, - _unsigned_integer_dtypes, - _integer_dtypes, - _real_floating_dtypes, _complex_floating_dtypes, + _integer_dtypes, _numeric_dtypes, + _real_floating_dtypes, _result_type, + _signed_integer_dtypes, + _unsigned_integer_dtypes, ) from ._flags import get_array_api_strict_flags -from dataclasses import dataclass -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import List, Tuple, Union, Optional - from ._typing import Dtype, Device - -import numpy as np - -# Use to emulate the asarray(device) argument not existing in 2022.12 -_default = object() # Note: astype is a function, not an array method as in NumPy. def astype( - x: Array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = _default + x: Array, + dtype: DType, + /, + *, + copy: bool = True, + # _default is used to emulate the device argument not existing in 2022.12 + device: Device | Undef | None = _undef, ) -> Array: - if device is not _default: + if device is not _undef: if get_array_api_strict_flags()['api_version'] >= '2023.12': _check_device(device) else: @@ -52,7 +51,7 @@ def astype( return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy), device=device) -def broadcast_arrays(*arrays: Array) -> List[Array]: +def broadcast_arrays(*arrays: Array) -> list[Array]: """ Array API compatible wrapper for :py:func:`np.broadcast_arrays `. @@ -65,7 +64,7 @@ def broadcast_arrays(*arrays: Array) -> List[Array]: ] -def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array: +def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array: """ Array API compatible wrapper for :py:func:`np.broadcast_to `. @@ -76,7 +75,7 @@ def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array: return Array._new(np.broadcast_to(x._array, shape), device=x.device) -def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool: +def can_cast(from_: DType | Array, to: DType, /) -> bool: """ Array API compatible wrapper for :py:func:`np.can_cast `. @@ -112,7 +111,7 @@ class finfo_object: max: float min: float smallest_normal: float - dtype: Dtype + dtype: DType @dataclass @@ -120,18 +119,17 @@ class iinfo_object: bits: int max: int min: int - dtype: Dtype + dtype: DType -def finfo(type: Union[Dtype, Array], /) -> finfo_object: +def finfo(type: DType | Array, /) -> finfo_object: """ Array API compatible wrapper for :py:func:`np.finfo `. See its docstring for more information. """ - if isinstance(type, _DType): - type = type._np_dtype - fi = np.finfo(type) + np_type = type._array if isinstance(type, Array) else type._np_dtype + fi = np.finfo(np_type) # Note: The types of the float data here are float, whereas in NumPy they # are scalars of the corresponding float dtype. return finfo_object( @@ -140,35 +138,33 @@ def finfo(type: Union[Dtype, Array], /) -> finfo_object: float(fi.max), float(fi.min), float(fi.smallest_normal), - fi.dtype, + DType(fi.dtype), ) -def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: +def iinfo(type: DType | Array, /) -> iinfo_object: """ Array API compatible wrapper for :py:func:`np.iinfo `. See its docstring for more information. """ - if isinstance(type, _DType): - type = type._np_dtype - ii = np.iinfo(type) - return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype) + np_type = type._array if isinstance(type, Array) else type._np_dtype + ii = np.iinfo(np_type) + return iinfo_object(ii.bits, ii.max, ii.min, DType(ii.dtype)) # Note: isdtype is a new function from the 2022.12 array API specification. -def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]] -) -> bool: +def isdtype(dtype: DType, kind: DType | str | tuple[DType | str, ...]) -> bool: """ - Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. + Returns a boolean indicating whether a provided dtype is of a specified + data type ``kind``. See https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html for more details """ - if not isinstance(dtype, _DType): - raise TypeError(f"'dtype' must be a dtype, not a {type(dtype)!r}") + if not isinstance(dtype, DType): + raise TypeError(f"'dtype' must be a dtype, not a {type(dtype)!r}") if isinstance(kind, tuple): # Disallow nested tuples @@ -197,7 +193,10 @@ def isdtype( else: raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}") -def result_type(*arrays_and_dtypes: Union[Array, Dtype, int, float, complex, bool]) -> Dtype: + +def result_type( + *arrays_and_dtypes: DType | Array | bool | int | float | complex, +) -> DType: """ Array API compatible wrapper for :py:func:`np.result_type `. @@ -219,15 +218,15 @@ def result_type(*arrays_and_dtypes: Union[Array, Dtype, int, float, complex, boo A.append(a) # remove python scalars - A = [a for a in A if not isinstance(a, (bool, int, float, complex))] + B = [a for a in A if not isinstance(a, (bool, int, float, complex))] - if len(A) == 0: + if len(B) == 0: raise ValueError("at least one array or dtype is required") - elif len(A) == 1: - result = A[0] + elif len(B) == 1: + result = B[0] else: - t = A[0] - for t2 in A[1:]: + t = B[0] + for t2 in B[1:]: t = _result_type(t, t2) result = t diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index 66304dd..513650b 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ -1,19 +1,27 @@ +from __future__ import annotations + +import builtins import warnings +from typing import Any, Final import numpy as np +import numpy.typing as npt # Note: we wrap the NumPy dtype objects in a bare class, so that none of the # additional methods and behaviors of NumPy dtype objects are exposed. -class _DType: - def __init__(self, np_dtype): - np_dtype = np.dtype(np_dtype) - self._np_dtype = np_dtype - def __repr__(self): +class DType: + _np_dtype: Final[np.dtype[Any]] + __slots__ = ("_np_dtype", "__weakref__") + + def __init__(self, np_dtype: npt.DTypeLike): + self._np_dtype = np.dtype(np_dtype) + + def __repr__(self) -> str: return f"array_api_strict.{self._np_dtype.name}" - def __eq__(self, other): + def __eq__(self, other: object) -> builtins.bool: # See https://github.com/numpy/numpy/pull/25370/files#r1423259515. # Avoid the user error of array_api_strict.float32 == numpy.float32, # which gives False. Making == error is probably too egregious, so @@ -26,12 +34,13 @@ def __eq__(self, other): a NumPy native dtype object, but you probably don't want to do this. \ array_api_strict dtype objects compare unequal to their NumPy equivalents. \ Such cross-library comparison is not supported by the standard.""", - stacklevel=2) - if not isinstance(other, _DType): + stacklevel=2, + ) + if not isinstance(other, DType): return NotImplemented return self._np_dtype == other._np_dtype - def __hash__(self): + def __hash__(self) -> int: # Note: this is not strictly required # (https://github.com/data-apis/array-api/issues/582), but makes the # dtype objects much easier to work with here and elsewhere if they @@ -39,20 +48,24 @@ def __hash__(self): return hash(self._np_dtype) -int8 = _DType("int8") -int16 = _DType("int16") -int32 = _DType("int32") -int64 = _DType("int64") -uint8 = _DType("uint8") -uint16 = _DType("uint16") -uint32 = _DType("uint32") -uint64 = _DType("uint64") -float32 = _DType("float32") -float64 = _DType("float64") -complex64 = _DType("complex64") -complex128 = _DType("complex128") +def _np_dtype(dtype: DType | None) -> np.dtype[Any] | None: + return dtype._np_dtype if dtype is not None else None + + +int8 = DType("int8") +int16 = DType("int16") +int32 = DType("int32") +int64 = DType("int64") +uint8 = DType("uint8") +uint16 = DType("uint16") +uint32 = DType("uint32") +uint64 = DType("uint64") +float32 = DType("float32") +float64 = DType("float64") +complex64 = DType("complex64") +complex128 = DType("complex128") # Note: This name is changed -bool = _DType("bool") +bool = DType("bool") _all_dtypes = ( int8, @@ -212,7 +225,7 @@ def __hash__(self): } -def _result_type(type1, type2): +def _result_type(type1: DType, type2: DType) -> DType: if (type1, type2) in _promotion_table: return _promotion_table[type1, type2] raise TypeError(f"{type1} and {type2} cannot be type promoted together") diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index c11b17c..6b52a58 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -1,51 +1,50 @@ from __future__ import annotations +import numpy as np + +from ._array_object import Array +from ._creation_functions import asarray +from ._data_type_functions import broadcast_to, iinfo from ._dtypes import ( _boolean_dtypes, - _floating_dtypes, - _real_floating_dtypes, _complex_floating_dtypes, + _dtype_categories, + _floating_dtypes, _integer_dtypes, _integer_or_boolean_dtypes, - _real_numeric_dtypes, _numeric_dtypes, + _real_floating_dtypes, + _real_numeric_dtypes, _result_type, - _dtype_categories, ) -from ._array_object import Array from ._flags import requires_api_version -from ._creation_functions import asarray -from ._data_type_functions import broadcast_to, iinfo from ._helpers import _maybe_normalize_py_scalars -from typing import Optional, Union - -import numpy as np - def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func): """Base implementation of a binary function, `func_name`, defined for - dtypes from `dtype_category` + dtypes from `dtype_category` """ x1, x2 = _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name) if x1.device != x2.device: - raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") + raise ValueError( + f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined." + ) # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) return Array._new(np_func(x1._array, x2._array), device=x1.device) -_binary_docstring_template=\ -""" +_binary_docstring_template = """ Array API compatible wrapper for :py:func:`np.%s `. See its docstring for more information. """ -def create_binary_func(func_name, dtype_category, np_func): +def _create_binary_func(func_name, dtype_category, np_func): def inner(x1, x2, /) -> Array: return _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func) return inner @@ -58,7 +57,7 @@ def inner(x1, x2, /) -> Array: "real numeric": "int | float | Array", "numeric": "int | float | complex | Array", "integer": "int | Array", - "integer or boolean": "int | bool | Array", + "integer or boolean": "bool | int | Array", "boolean": "bool | Array", "real floating-point": "float | Array", "complex floating-point": "complex | Array", @@ -75,7 +74,7 @@ def inner(x1, x2, /) -> Array: "bitwise_xor": "integer or boolean", "_bitwise_left_shift": "integer", # leading underscore deliberate "_bitwise_right_shift": "integer", - # XXX: copysign: real fp or numeric? + # XXX: copysign: real fp or numeric? "copysign": "real floating-point", "divide": "floating-point", "equal": "all", @@ -105,7 +104,7 @@ def inner(x1, x2, /) -> Array: "atan2": "arctan2", "_bitwise_left_shift": "left_shift", "_bitwise_right_shift": "right_shift", - "pow": "power" + "pow": "power", } @@ -117,7 +116,7 @@ def inner(x1, x2, /) -> Array: numpy_name = _numpy_renames.get(func_name, func_name) np_func = getattr(np, numpy_name) - func = create_binary_func(func_name, dtype_category, np_func) + func = _create_binary_func(func_name, dtype_category, np_func) func.__name__ = func_name func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name) @@ -153,7 +152,7 @@ def bitwise_right_shift(x1: int | Array, x2: int | Array, /) -> Array: # clean up to not pollute the namespace -del func, create_binary_func +del func, _create_binary_func def abs(x: Array, /) -> Array: @@ -271,8 +270,8 @@ def ceil(x: Array, /) -> Array: def clip( x: Array, /, - min: Optional[Union[int, float, Array]] = None, - max: Optional[Union[int, float, Array]] = None, + min: Array | int | float | None = None, + max: Array | int | float | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.clip `. @@ -351,6 +350,7 @@ def clip( def _isscalar(a): return isinstance(a, (int, float, type(None))) + min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -584,6 +584,7 @@ def reciprocal(x: Array, /) -> Array: raise TypeError("Only floating-point dtypes are allowed in reciprocal") return Array._new(np.reciprocal(x._array), device=x.device) + def round(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.round `. diff --git a/array_api_strict/_fft.py b/array_api_strict/_fft.py index c888826..2998254 100644 --- a/array_api_strict/_fft.py +++ b/array_api_strict/_fft.py @@ -1,31 +1,29 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from collections.abc import Sequence +from typing import Literal -if TYPE_CHECKING: - from typing import Union, Optional, Literal - from ._typing import Device, Dtype as DType - from collections.abc import Sequence +import numpy as np +from ._array_object import ALL_DEVICES, Array, Device +from ._data_type_functions import astype from ._dtypes import ( + DType, + _complex_floating_dtypes, _floating_dtypes, _real_floating_dtypes, - _complex_floating_dtypes, - float32, complex64, + float32, ) -from ._array_object import Array, ALL_DEVICES -from ._data_type_functions import astype from ._flags import requires_extension -import numpy as np @requires_extension('fft') def fft( x: Array, /, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: @@ -48,7 +46,7 @@ def ifft( x: Array, /, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: @@ -71,8 +69,8 @@ def fftn( x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: """ @@ -94,8 +92,8 @@ def ifftn( x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: """ @@ -117,7 +115,7 @@ def rfft( x: Array, /, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: @@ -140,7 +138,7 @@ def irfft( x: Array, /, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: @@ -163,8 +161,8 @@ def rfftn( x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: """ @@ -186,8 +184,8 @@ def irfftn( x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: """ @@ -209,7 +207,7 @@ def hfft( x: Array, /, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: @@ -232,7 +230,7 @@ def ihfft( x: Array, /, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", ) -> Array: @@ -256,8 +254,8 @@ def fftfreq( /, *, d: float = 1.0, - dtype: Optional[DType] = None, - device: Optional[Device] = None + dtype: DType | None = None, + device: Device | None = None ) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.fftfreq `. @@ -280,8 +278,8 @@ def rfftfreq( /, *, d: float = 1.0, - dtype: Optional[DType] = None, - device: Optional[Device] = None + dtype: DType | None = None, + device: Device | None = None ) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.rfftfreq `. @@ -299,7 +297,7 @@ def rfftfreq( return Array._new(np_result, device=device) @requires_extension('fft') -def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: +def fftshift(x: Array, /, *, axes: int | Sequence[int] | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.fftshift `. @@ -310,7 +308,7 @@ def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: return Array._new(np.fft.fftshift(x._array, axes=axes), device=x.device) @requires_extension('fft') -def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: +def ifftshift(x: Array, /, *, axes: int | Sequence[int] | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.ifftshift `. diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 3fce8a0..6729a4b 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -12,17 +12,32 @@ """ +from __future__ import annotations + import functools import os import warnings +from collections.abc import Callable +from types import TracebackType +from typing import TYPE_CHECKING, Any, Collection, TypeVar, cast import array_api_strict +if TYPE_CHECKING: + # TODO import from typing (requires Python >= 3.10) + from typing_extensions import ParamSpec + + P = ParamSpec("P") + +T = TypeVar("T") +_CallableT = TypeVar("_CallableT", bound=Callable[..., object]) + + supported_versions = ( "2021.12", "2022.12", "2023.12", - "2024.12" + "2024.12", ) draft_version = "2025.12" @@ -43,19 +58,23 @@ "fft": "2022.12", } -ENABLED_EXTENSIONS = default_extensions = ( +default_extensions: tuple[str, ...] = ( "linalg", "fft", ) +ENABLED_EXTENSIONS = default_extensions + + # Public functions + def set_array_api_strict_flags( *, - api_version=None, - boolean_indexing=None, - data_dependent_shapes=None, - enabled_extensions=None, -): + api_version: str | None = None, + boolean_indexing: bool | None = None, + data_dependent_shapes: bool | None = None, + enabled_extensions: Collection[str] | None = None, +) -> None: """ Set the array-api-strict flags to the specified values. @@ -178,7 +197,8 @@ def set_array_api_strict_flags( draft_version=draft_version, ) -def get_array_api_strict_flags(): + +def get_array_api_strict_flags() -> dict[str, Any]: """ Get the current array-api-strict flags. @@ -228,7 +248,7 @@ def get_array_api_strict_flags(): } -def reset_array_api_strict_flags(): +def reset_array_api_strict_flags() -> None: """ Reset the array-api-strict flags to their default values. @@ -300,8 +320,19 @@ class ArrayAPIStrictFlags: reset_array_api_strict_flags: Reset the flags to their default values. """ - def __init__(self, *, api_version=None, boolean_indexing=None, - data_dependent_shapes=None, enabled_extensions=None): + + kwargs: dict[str, Any] + old_flags: dict[str, Any] + __slots__ = ("kwargs", "old_flags") + + def __init__( + self, + *, + api_version: str | None = None, + boolean_indexing: bool | None = None, + data_dependent_shapes: bool | None = None, + enabled_extensions: Collection[str] | None = None, + ): self.kwargs = { "api_version": api_version, "boolean_indexing": boolean_indexing, @@ -310,12 +341,19 @@ def __init__(self, *, api_version=None, boolean_indexing=None, } self.old_flags = get_array_api_strict_flags() - def __enter__(self): + def __enter__(self) -> None: set_array_api_strict_flags(**self.kwargs) - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + /, + ) -> None: set_array_api_strict_flags(**self.old_flags) + # Private functions ENVIRONMENT_VARIABLES = [ @@ -325,8 +363,9 @@ def __exit__(self, exc_type, exc_value, traceback): "ARRAY_API_STRICT_ENABLED_EXTENSIONS", ] -def set_flags_from_environment(): - kwargs = {} + +def set_flags_from_environment() -> None: + kwargs: dict[str, Any] = {} if "ARRAY_API_STRICT_API_VERSION" in os.environ: kwargs["api_version"] = os.environ["ARRAY_API_STRICT_API_VERSION"] @@ -346,35 +385,41 @@ def set_flags_from_environment(): # linalg and fft to __all__ set_array_api_strict_flags(**kwargs) + set_flags_from_environment() # Decorators -def requires_api_version(version): - def decorator(func): + +def requires_api_version(version: str) -> Callable[[_CallableT], _CallableT]: + def decorator(func: Callable[P, T]) -> Callable[P, T]: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: if version > API_VERSION: raise RuntimeError( f"The function {func.__name__} requires API version {version} or later, " f"but the current API version for array-api-strict is {API_VERSION}" ) return func(*args, **kwargs) + return wrapper - return decorator -def requires_data_dependent_shapes(func): + return cast(Callable[[_CallableT], _CallableT], decorator) + + +def requires_data_dependent_shapes(func: Callable[P, T]) -> Callable[P, T]: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: if not DATA_DEPENDENT_SHAPES: raise RuntimeError(f"The function {func.__name__} requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") return func(*args, **kwargs) return wrapper -def requires_extension(extension): - def decorator(func): + +def requires_extension(extension: str) -> Callable[[_CallableT], _CallableT]: + def decorator(func: Callable[P, T]) -> Callable[P, T]: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: if extension not in ENABLED_EXTENSIONS: if extension == 'linalg' \ and func.__name__ in ['matmul', 'tensordot', @@ -382,5 +427,7 @@ def wrapper(*args, **kwargs): raise RuntimeError(f"The linalg extension has been disabled for array-api-strict. However, {func.__name__} is also present in the main array_api_strict namespace and may be used from there.") raise RuntimeError(f"The function {func.__name__} requires the {extension} extension, but it has been disabled for array-api-strict") return func(*args, **kwargs) + return wrapper - return decorator + + return cast(Callable[[_CallableT], _CallableT], decorator) diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py index d3fc9c9..291082e 100644 --- a/array_api_strict/_helpers.py +++ b/array_api_strict/_helpers.py @@ -1,18 +1,24 @@ -"""Private helper routines. -""" +"""Private helper routines.""" -from ._flags import get_array_api_strict_flags +from __future__ import annotations + +from ._array_object import Array from ._dtypes import _dtype_categories +from ._flags import get_array_api_strict_flags _py_scalars = (bool, int, float, complex) -def _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name): - +def _maybe_normalize_py_scalars( + x1: Array | bool | int | float | complex, + x2: Array | bool | int | float | complex, + dtype_category: str, + func_name: str, +) -> tuple[Array, Array]: flags = get_array_api_strict_flags() if flags["api_version"] < "2024.12": # scalars will fail at the call site - return x1, x2 + return x1, x2 # type: ignore[return-value] _allowed_dtypes = _dtype_categories[dtype_category] @@ -34,4 +40,3 @@ def _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name): raise TypeError(f"Only {dtype_category} dtypes are allowed in {func_name}(...). " f"Got {x1.dtype} and {x2.dtype}.") return x1, x2 - diff --git a/array_api_strict/_indexing_functions.py b/array_api_strict/_indexing_functions.py index d7a400e..ab25fab 100644 --- a/array_api_strict/_indexing_functions.py +++ b/array_api_strict/_indexing_functions.py @@ -1,17 +1,13 @@ from __future__ import annotations +import numpy as np + from ._array_object import Array from ._dtypes import _integer_dtypes from ._flags import requires_api_version -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional -import numpy as np - -def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: +def take(x: Array, indices: Array, /, *, axis: int | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.take `. @@ -27,6 +23,7 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.") return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device) + @requires_api_version('2024.12') def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: """ diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index a9dbebf..81f88bc 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -1,25 +1,22 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import numpy as np -if TYPE_CHECKING: - from typing import Optional, Union, Tuple, List - from ._typing import device, DefaultDataTypes, DataTypes, Capabilities - -from ._array_object import ALL_DEVICES, CPU_DEVICE +from . import _dtypes as dt +from ._array_object import ALL_DEVICES, CPU_DEVICE, Device from ._flags import get_array_api_strict_flags, requires_api_version -from ._dtypes import bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 +from ._typing import Capabilities, DataTypes, DefaultDataTypes + @requires_api_version('2023.12') class __array_namespace_info__: @requires_api_version('2023.12') def capabilities(self) -> Capabilities: flags = get_array_api_strict_flags() - res = {"boolean indexing": flags['boolean_indexing'], - "data-dependent shapes": flags['data_dependent_shapes'], - } + res: Capabilities = { # type: ignore[typeddict-item] + "boolean indexing": flags['boolean_indexing'], + "data-dependent shapes": flags['data_dependent_shapes'], + } if flags['api_version'] >= '2024.12': # maxdims is 32 for NumPy 1.x and 64 for NumPy 2.0. Eventually we will # drop support for NumPy 1 but for now, just compute the number @@ -36,104 +33,104 @@ def capabilities(self) -> Capabilities: return res @requires_api_version('2023.12') - def default_device(self) -> device: + def default_device(self) -> Device: return CPU_DEVICE @requires_api_version('2023.12') def default_dtypes( self, *, - device: Optional[device] = None, + device: Device | None = None, ) -> DefaultDataTypes: return { - "real floating": float64, - "complex floating": complex128, - "integral": int64, - "indexing": int64, + "real floating": dt.float64, + "complex floating": dt.complex128, + "integral": dt.int64, + "indexing": dt.int64, } @requires_api_version('2023.12') def dtypes( self, *, - device: Optional[device] = None, - kind: Optional[Union[str, Tuple[str, ...]]] = None, + device: Device | None = None, + kind: str | tuple[str, ...] | None = None, ) -> DataTypes: if kind is None: return { - "bool": bool, - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - "float32": float32, - "float64": float64, - "complex64": complex64, - "complex128": complex128, + "bool": dt.bool, + "int8": dt.int8, + "int16": dt.int16, + "int32": dt.int32, + "int64": dt.int64, + "uint8": dt.uint8, + "uint16": dt.uint16, + "uint32": dt.uint32, + "uint64": dt.uint64, + "float32": dt.float32, + "float64": dt.float64, + "complex64": dt.complex64, + "complex128": dt.complex128, } if kind == "bool": - return {"bool": bool} + return {"bool": dt.bool} if kind == "signed integer": return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, + "int8": dt.int8, + "int16": dt.int16, + "int32": dt.int32, + "int64": dt.int64, } if kind == "unsigned integer": return { - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, + "uint8": dt.uint8, + "uint16": dt.uint16, + "uint32": dt.uint32, + "uint64": dt.uint64, } if kind == "integral": return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, + "int8": dt.int8, + "int16": dt.int16, + "int32": dt.int32, + "int64": dt.int64, + "uint8": dt.uint8, + "uint16": dt.uint16, + "uint32": dt.uint32, + "uint64": dt.uint64, } if kind == "real floating": return { - "float32": float32, - "float64": float64, + "float32": dt.float32, + "float64": dt.float64, } if kind == "complex floating": return { - "complex64": complex64, - "complex128": complex128, + "complex64": dt.complex64, + "complex128": dt.complex128, } if kind == "numeric": return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "uint16": uint16, - "uint32": uint32, - "uint64": uint64, - "float32": float32, - "float64": float64, - "complex64": complex64, - "complex128": complex128, + "int8": dt.int8, + "int16": dt.int16, + "int32": dt.int32, + "int64": dt.int64, + "uint8": dt.uint8, + "uint16": dt.uint16, + "uint32": dt.uint32, + "uint64": dt.uint64, + "float32": dt.float32, + "float64": dt.float64, + "complex64": dt.complex64, + "complex128": dt.complex128, } if isinstance(kind, tuple): - res = {} + res: DataTypes = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") @requires_api_version('2023.12') - def devices(self) -> List[device]: + def devices(self) -> list[Device]: return list(ALL_DEVICES) diff --git a/array_api_strict/_linalg.py b/array_api_strict/_linalg.py index 7d379a0..27a2ddf 100644 --- a/array_api_strict/_linalg.py +++ b/array_api_strict/_linalg.py @@ -1,33 +1,25 @@ from __future__ import annotations +from collections.abc import Sequence from functools import partial +from typing import Literal, NamedTuple -from ._dtypes import ( - _floating_dtypes, - _numeric_dtypes, - float32, - complex64, - complex128, -) +import numpy as np +import numpy.linalg + +from ._array_object import Array from ._data_type_functions import finfo -from ._manipulation_functions import reshape +from ._dtypes import DType, _floating_dtypes, _numeric_dtypes, complex64, complex128 from ._elementwise_functions import conj -from ._array_object import Array -from ._flags import requires_extension, get_array_api_strict_flags +from ._flags import get_array_api_strict_flags, requires_extension +from ._manipulation_functions import reshape +from ._statistical_functions import _np_dtype_sumprod try: - from numpy._core.numeric import normalize_axis_tuple + from numpy._core.numeric import normalize_axis_tuple # type: ignore[attr-defined] except ImportError: - from numpy.core.numeric import normalize_axis_tuple + from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef] -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ._typing import Literal, Optional, Sequence, Tuple, Union, Dtype - -from typing import NamedTuple - -import numpy.linalg -import numpy as np class EighResult(NamedTuple): eigenvalues: Array @@ -175,7 +167,13 @@ def inv(x: Array, /) -> Array: # -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point # literals. @requires_extension('linalg') -def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array: # noqa: F821 +def matrix_norm( + x: Array, + /, + *, + keepdims: bool = False, + ord: float | Literal["fro", "nuc"] | None = "fro", +) -> Array: # noqa: F821 """ Array API compatible wrapper for :py:func:`np.linalg.norm `. @@ -186,7 +184,10 @@ def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in matrix_norm') - return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord), device=x.device) + return Array._new( + np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord), + device=x.device, + ) @requires_extension('linalg') @@ -206,7 +207,7 @@ def matrix_power(x: Array, n: int, /) -> Array: # Note: the keyword argument name rtol is different from np.linalg.matrix_rank @requires_extension('linalg') -def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: +def matrix_rank(x: Array, /, *, rtol: float | Array | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.matrix_rank `. @@ -218,13 +219,12 @@ def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> A raise np.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") S = np.linalg.svd(x._array, compute_uv=False) if rtol is None: - tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * finfo(S.dtype).eps + tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * np.finfo(S.dtype).eps else: - if isinstance(rtol, Array): - rtol = rtol._array + rtol_np = rtol._array if isinstance(rtol, Array) else np.asarray(rtol) # Note: this is different from np.linalg.matrix_rank, which does not multiply # the tolerance by the largest singular value. - tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis] + tol = S.max(axis=-1, keepdims=True) * rtol_np[..., np.newaxis] return Array._new(np.count_nonzero(S > tol, axis=-1), device=x.device) @@ -252,7 +252,7 @@ def outer(x1: Array, x2: Array, /) -> Array: # Note: the keyword argument name rtol is different from np.linalg.pinv @requires_extension('linalg') -def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: +def pinv(x: Array, /, *, rtol: float | Array | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.pinv `. @@ -267,9 +267,8 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: # default tolerance by max(M, N). if rtol is None: rtol = max(x.shape[-2:]) * finfo(x.dtype).eps - if isinstance(rtol, Array): - rtol = rtol._array - return Array._new(np.linalg.pinv(x._array, rcond=rtol), device=x.device) + rtol_np = rtol._array if isinstance(rtol, Array) else rtol + return Array._new(np.linalg.pinv(x._array, rcond=rtol_np), device=x.device) @requires_extension('linalg') def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: # noqa: F821 @@ -312,14 +311,14 @@ def slogdet(x: Array, /) -> SlogdetResult: # To workaround this, the below is the code from np.linalg.solve except # only calling solve1 in the exactly 1D case. -def _solve(a, b): +def _solve(a: np.ndarray, b: np.ndarray) -> np.ndarray: try: - from numpy.linalg._linalg import ( + from numpy.linalg._linalg import ( # type: ignore[attr-defined] _makearray, _assert_stacked_2d, _assert_stacked_square, _commonType, isComplexType, _raise_linalgerror_singular ) except ImportError: - from numpy.linalg.linalg import ( + from numpy.linalg.linalg import ( # type: ignore[attr-defined] _makearray, _assert_stacked_2d, _assert_stacked_square, _commonType, isComplexType, _raise_linalgerror_singular ) @@ -382,14 +381,14 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: # Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to # np.linalg.svd(compute_uv=False). @requires_extension('linalg') -def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]: +def svdvals(x: Array, /) -> Array: if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in svdvals') return Array._new(np.linalg.svd(x._array, compute_uv=False), device=x.device) # Note: trace is the numpy top-level namespace, not np.linalg @requires_extension('linalg') -def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Array: +def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.trace `. @@ -398,19 +397,13 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr if x.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in trace') - # Note: trace() works the same as sum() and prod() (see - # _statistical_functions.py) - if dtype is None: - if get_array_api_strict_flags()['api_version'] < '2023.12': - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 - else: - dtype = dtype._np_dtype + # Note: trace() works the same as sum() and prod() (see _statistical_functions.py) + np_dtype = _np_dtype_sumprod(x, dtype) + # Note: trace always operates on the last two axes, whereas np.trace # operates on the first two axes by default - return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype)), device=x.device) + res = np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=np_dtype) + return Array._new(np.asarray(res), device=x.device) # Note: the name here is different from norm(). The array API norm is split # into matrix_norm and vector_norm(). @@ -418,7 +411,14 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr # The type for ord should be Optional[Union[int, float, Literal[np.inf, # -np.inf]]] but Literal does not support floating-point literals. @requires_extension('linalg') -def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array: +def vector_norm( + x: Array, + /, + *, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, + ord: int | float = 2, +) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm `. @@ -456,8 +456,8 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No # We can't reuse np.linalg.norm(keepdims) because of the reshape hacks # above to avoid matrix norm logic. shape = list(x.shape) - _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) - for i in _axis: + axis_tup = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) + for i in axis_tup: shape[i] = 1 res = reshape(res, tuple(shape)) @@ -480,7 +480,13 @@ def matmul(x1: Array, x2: Array, /) -> Array: # Note: tensordot is the numpy top-level namespace but not in np.linalg @requires_extension('linalg') -def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: +def tensordot( + x1: Array, + x2: Array, + /, + *, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, +) -> Array: from ._linear_algebra_functions import tensordot return tensordot(x1, x2, axes=axes) diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py index 6af2a15..d18214c 100644 --- a/array_api_strict/_linear_algebra_functions.py +++ b/array_api_strict/_linear_algebra_functions.py @@ -7,16 +7,15 @@ from __future__ import annotations -from ._dtypes import _numeric_dtypes +from collections.abc import Sequence + +import numpy as np +import numpy.linalg + from ._array_object import Array +from ._dtypes import _numeric_dtypes from ._flags import get_array_api_strict_flags -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ._typing import Sequence, Tuple, Union - -import numpy.linalg -import numpy as np # Note: matmul is the numpy top-level namespace but not in np.linalg def matmul(x1: Array, x2: Array, /) -> Array: @@ -38,7 +37,13 @@ def matmul(x1: Array, x2: Array, /) -> Array: # Note: tensordot is the numpy top-level namespace but not in np.linalg # Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. -def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: +def tensordot( + x1: Array, + x2: Array, + /, + *, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, +) -> Array: # Note: the restriction to numeric dtypes only is different from # np.tensordot. if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index 63c3516..e2fd24c 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -1,21 +1,17 @@ from __future__ import annotations +import numpy as np + from ._array_object import Array from ._creation_functions import asarray from ._data_type_functions import astype, result_type from ._dtypes import _integer_dtypes, int64, uint64 -from ._flags import requires_api_version, get_array_api_strict_flags - -from typing import TYPE_CHECKING +from ._flags import get_array_api_strict_flags, requires_api_version -if TYPE_CHECKING: - from typing import List, Optional, Tuple, Union - -import numpy as np # Note: the function name is different here def concat( - arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0 + arrays: tuple[Array, ...] | list[Array], /, *, axis: int | None = 0 ) -> Array: """ Array API compatible wrapper for :py:func:`np.concatenate `. @@ -29,8 +25,11 @@ def concat( raise ValueError("concat inputs must all be on the same device") result_device = arrays[0].device - arrays = tuple(a._array for a in arrays) - return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype._np_dtype), device=result_device) + np_arrays = tuple(a._array for a in arrays) + return Array._new( + np.concatenate(np_arrays, axis=axis, dtype=dtype._np_dtype), + device=result_device, + ) def expand_dims(x: Array, /, *, axis: int) -> Array: @@ -42,7 +41,7 @@ def expand_dims(x: Array, /, *, axis: int) -> Array: return Array._new(np.expand_dims(x._array, axis), device=x.device) -def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: +def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.flip `. @@ -53,8 +52,8 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> @requires_api_version('2023.12') def moveaxis( x: Array, - source: Union[int, Tuple[int, ...]], - destination: Union[int, Tuple[int, ...]], + source: int | tuple[int, ...], + destination: int | tuple[int, ...], /, ) -> Array: """ @@ -66,7 +65,7 @@ def moveaxis( # Note: The function name is different here (see also matrix_transpose). # Unlike transpose(), the axes argument is required. -def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: +def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array: """ Array API compatible wrapper for :py:func:`np.transpose `. @@ -77,10 +76,10 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: @requires_api_version('2023.12') def repeat( x: Array, - repeats: Union[int, Array], + repeats: int | Array, /, *, - axis: Optional[int] = None, + axis: int | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.repeat `. @@ -108,12 +107,9 @@ def repeat( repeats = astype(repeats, int64) return Array._new(np.repeat(x._array, repeats._array, axis=axis), device=x.device) + # Note: the optional argument is called 'shape', not 'newshape' -def reshape(x: Array, - /, - shape: Tuple[int, ...], - *, - copy: Optional[bool] = None) -> Array: +def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array: """ Array API compatible wrapper for :py:func:`np.reshape `. @@ -135,9 +131,9 @@ def reshape(x: Array, def roll( x: Array, /, - shift: Union[int, Tuple[int, ...]], + shift: int | tuple[int, ...], *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.roll `. @@ -147,7 +143,7 @@ def roll( return Array._new(np.roll(x._array, shift, axis=axis), device=x.device) -def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: +def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array: """ Array API compatible wrapper for :py:func:`np.squeeze `. @@ -161,7 +157,7 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: return Array._new(np.squeeze(x._array, axis=axis), device=x.device) -def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array: +def stack(arrays: tuple[Array, ...] | list[Array], /, *, axis: int = 0) -> Array: """ Array API compatible wrapper for :py:func:`np.stack `. @@ -172,12 +168,12 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> if len({a.device for a in arrays}) > 1: raise ValueError("concat inputs must all be on the same device") result_device = arrays[0].device - arrays = tuple(a._array for a in arrays) - return Array._new(np.stack(arrays, axis=axis), device=result_device) + np_arrays = tuple(a._array for a in arrays) + return Array._new(np.stack(np_arrays, axis=axis), device=result_device) @requires_api_version('2023.12') -def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array: +def tile(x: Array, repetitions: tuple[int, ...], /) -> Array: """ Array API compatible wrapper for :py:func:`np.tile `. @@ -190,7 +186,7 @@ def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array: # Note: this function is new @requires_api_version('2023.12') -def unstack(x: Array, /, *, axis: int = 0) -> Tuple[Array, ...]: +def unstack(x: Array, /, *, axis: int = 0) -> tuple[Array, ...]: if not (-x.ndim <= axis < x.ndim): raise ValueError("axis out of range") diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 9864132..b366ed9 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -1,18 +1,17 @@ from __future__ import annotations -from ._array_object import Array -from ._dtypes import _result_type, _real_numeric_dtypes, bool as _bool -from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags -from ._helpers import _maybe_normalize_py_scalars - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Literal, Optional, Tuple, Union +from typing import Literal import numpy as np +from ._array_object import Array +from ._dtypes import _real_numeric_dtypes, _result_type +from ._dtypes import bool as _bool +from ._flags import requires_api_version, requires_data_dependent_shapes +from ._helpers import _maybe_normalize_py_scalars + -def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: +def argmax(x: Array, /, *, axis: int | None = None, keepdims: bool = False) -> Array: """ Array API compatible wrapper for :py:func:`np.argmax `. @@ -23,7 +22,7 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims)), device=x.device) -def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array: +def argmin(x: Array, /, *, axis: int | None = None, keepdims: bool = False) -> Array: """ Array API compatible wrapper for :py:func:`np.argmin `. @@ -35,7 +34,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - @requires_data_dependent_shapes -def nonzero(x: Array, /) -> Tuple[Array, ...]: +def nonzero(x: Array, /) -> tuple[Array, ...]: """ Array API compatible wrapper for :py:func:`np.nonzero `. @@ -52,7 +51,7 @@ def count_nonzero( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: """ @@ -71,7 +70,7 @@ def searchsorted( /, *, side: Literal["left", "right"] = "left", - sorter: Optional[Array] = None, + sorter: Array | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.searchsorted `. @@ -84,25 +83,29 @@ def searchsorted( if x1.device != x2.device: raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.") - sorter = sorter._array if sorter is not None else None + np_sorter = sorter._array if sorter is not None else None # TODO: The sort order of nans and signed zeros is implementation # dependent. Should we error/warn if they are present? # x1 must be 1-D, but NumPy already requires this. - return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device) + return Array._new( + np.searchsorted(x1._array, x2._array, side=side, sorter=np_sorter), + device=x1.device, + ) + def where( condition: Array, - x1: bool | int | float | complex | Array, - x2: bool | int | float | complex | Array, / + x1: Array | bool | int | float | complex, + x2: Array | bool | int | float | complex, + /, ) -> Array: """ Array API compatible wrapper for :py:func:`np.where `. See its docstring for more information. """ - if get_array_api_strict_flags()['api_version'] > '2023.12': - x1, x2 = _maybe_normalize_py_scalars(x1, x2, "all", "where") + x1, x2 = _maybe_normalize_py_scalars(x1, x2, "all", "where") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) diff --git a/array_api_strict/_set_functions.py b/array_api_strict/_set_functions.py index 7bd5bad..e677a52 100644 --- a/array_api_strict/_set_functions.py +++ b/array_api_strict/_set_functions.py @@ -1,13 +1,12 @@ from __future__ import annotations -from ._array_object import Array - -from ._flags import requires_data_dependent_shapes - from typing import NamedTuple import numpy as np +from ._array_object import Array +from ._flags import requires_data_dependent_shapes + # Note: np.unique() is split into four functions in the array API: # unique_all, unique_counts, unique_inverse, and unique_values (this is done # to remove polymorphic return types). @@ -20,6 +19,7 @@ # Note: The functions here return a namedtuple (np.unique() returns a normal # tuple). + class UniqueAllResult(NamedTuple): values: Array indices: Array diff --git a/array_api_strict/_sorting_functions.py b/array_api_strict/_sorting_functions.py index 765bd9e..e9193f1 100644 --- a/array_api_strict/_sorting_functions.py +++ b/array_api_strict/_sorting_functions.py @@ -1,10 +1,12 @@ from __future__ import annotations -from ._array_object import Array -from ._dtypes import _real_numeric_dtypes +from typing import Literal import numpy as np +from ._array_object import Array +from ._dtypes import _real_numeric_dtypes + # Note: the descending keyword argument is new in this function def argsort( @@ -18,7 +20,7 @@ def argsort( if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in argsort") # Note: this keyword argument is different, and the default is different. - kind = "stable" if stable else "quicksort" + kind: Literal["stable", "quicksort"] = "stable" if stable else "quicksort" if not descending: res = np.argsort(x._array, axis=axis, kind=kind) else: @@ -35,6 +37,7 @@ def argsort( res = max_i - res return Array._new(res, device=x.device) + # Note: the descending keyword argument is new in this function def sort( x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True @@ -47,8 +50,7 @@ def sort( if x.dtype not in _real_numeric_dtypes: raise TypeError("Only real numeric dtypes are allowed in sort") # Note: this keyword argument is different, and the default is different. - kind = "stable" if stable else "quicksort" - res = np.sort(x._array, axis=axis, kind=kind) + res = np.sort(x._array, axis=axis, kind="stable" if stable else "quicksort") if descending: res = np.flip(res, axis=axis) return Array._new(res, device=x.device) diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index e41e7ef..668cd02 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -1,38 +1,36 @@ from __future__ import annotations +from typing import Any + +import numpy as np + +from ._array_object import Array +from ._creation_functions import ones, zeros from ._dtypes import ( - _real_floating_dtypes, - _real_numeric_dtypes, + DType, _floating_dtypes, + _np_dtype, _numeric_dtypes, + _real_floating_dtypes, + _real_numeric_dtypes, + complex64, + float32, ) -from ._array_object import Array -from ._dtypes import float32, complex64 -from ._flags import requires_api_version, get_array_api_strict_flags -from ._creation_functions import zeros, ones +from ._flags import get_array_api_strict_flags, requires_api_version from ._manipulation_functions import concat -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional, Tuple, Union - from ._typing import Dtype - -import numpy as np @requires_api_version('2023.12') def cumulative_sum( x: Array, /, *, - axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in cumulative_sum") - if dtype is not None: - dtype = dtype._np_dtype # TODO: The standard is not clear about what should happen when x.ndim == 0. if axis is None: @@ -44,7 +42,7 @@ def cumulative_sum( if axis < 0: axis += x.ndim x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis) - return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device) + return Array._new(np.cumsum(x._array, axis=axis, dtype=_np_dtype(dtype)), device=x.device) @requires_api_version('2024.12') @@ -52,8 +50,8 @@ def cumulative_prod( x: Array, /, *, - axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, ) -> Array: if x.dtype not in _numeric_dtypes: @@ -61,9 +59,6 @@ def cumulative_prod( if x.ndim == 0: raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_prod") - if dtype is not None: - dtype = dtype._np_dtype - if axis is None: if x.ndim > 1: raise ValueError("axis must be specified in cumulative_prod for more than one dimension") @@ -74,14 +69,14 @@ def cumulative_prod( if axis < 0: axis += x.ndim x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis) - return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype), device=x.device) + return Array._new(np.cumprod(x._array, axis=axis, dtype=_np_dtype(dtype)), device=x.device) def max( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: if x.dtype not in _real_numeric_dtypes: @@ -93,14 +88,15 @@ def mean( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: - if get_array_api_strict_flags()['api_version'] > '2023.12': - allowed_dtypes = _floating_dtypes - else: - allowed_dtypes = _real_floating_dtypes + allowed_dtypes = ( + _floating_dtypes + if get_array_api_strict_flags()['api_version'] > '2023.12' + else _real_floating_dtypes + ) if x.dtype not in allowed_dtypes: raise TypeError("Only floating-point dtypes are allowed in mean") @@ -111,7 +107,7 @@ def min( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: if x.dtype not in _real_numeric_dtypes: @@ -119,37 +115,43 @@ def min( return Array._new(np.min(x._array, axis=axis, keepdims=keepdims), device=x.device) +def _np_dtype_sumprod(x: Array, dtype: DType | None) -> np.dtype[Any] | None: + """In versions prior to 2023.12, sum() and prod() upcast for all + dtypes when dtype=None. For 2023.12, the behavior is the same as in + NumPy (only upcast for integral dtypes). + """ + if dtype is None and get_array_api_strict_flags()['api_version'] < '2023.12': + if x.dtype == float32: + return np.float64 # type: ignore[return-value] + elif x.dtype == complex64: + return np.complex128 # type: ignore[return-value] + return _np_dtype(dtype) + + def prod( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in prod") - if dtype is None: - # Note: In versions prior to 2023.12, sum() and prod() upcast for all - # dtypes when dtype=None. For 2023.12, the behavior is the same as in - # NumPy (only upcast for integral dtypes). - if get_array_api_strict_flags()['api_version'] < '2023.12': - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 - else: - dtype = dtype._np_dtype - return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims), device=x.device) + np_dtype = _np_dtype_sumprod(x, dtype) + return Array._new( + np.prod(x._array, dtype=np_dtype, axis=axis, keepdims=keepdims), + device=x.device, + ) def std( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: int | float = 0.0, keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here @@ -162,33 +164,26 @@ def sum( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sum") - if dtype is None: - # Note: In versions prior to 2023.12, sum() and prod() upcast for all - # dtypes when dtype=None. For 2023.12, the behavior is the same as in - # NumPy (only upcast for integral dtypes). - if get_array_api_strict_flags()['api_version'] < '2023.12': - if x.dtype == float32: - dtype = np.float64 - elif x.dtype == complex64: - dtype = np.complex128 - else: - dtype = dtype._np_dtype - return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims), device=x.device) + np_dtype = _np_dtype_sumprod(x, dtype) + return Array._new( + np.sum(x._array, axis=axis, dtype=np_dtype, keepdims=keepdims), + device=x.device, + ) def var( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: int | float = 0.0, keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here diff --git a/array_api_strict/_typing.py b/array_api_strict/_typing.py index 94c4975..91095a8 100644 --- a/array_api_strict/_typing.py +++ b/array_api_strict/_typing.py @@ -8,41 +8,19 @@ from __future__ import annotations -__all__ = [ - "Array", - "Device", - "Dtype", - "SupportsDLPack", - "SupportsBufferProtocol", - "PyCapsule", -] - import sys +from typing import Any, Protocol, TypedDict, TypeVar -from typing import ( - Any, - TypedDict, - TypeVar, - Protocol, -) - -from ._array_object import Array, _device -from ._dtypes import _DType -from ._info import __array_namespace_info__ +from ._dtypes import DType _T_co = TypeVar("_T_co", covariant=True) + class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... -Device = _device - -Dtype = _DType - -Info = __array_namespace_info__ - if sys.version_info >= (3, 12): from collections.abc import Buffer as SupportsBufferProtocol else: @@ -50,40 +28,42 @@ def __len__(self, /) -> int: ... PyCapsule = Any + class SupportsDLPack(Protocol): def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ... + Capabilities = TypedDict( - "Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool, - "max dimensions": int} + "Capabilities", + { + "boolean indexing": bool, + "data-dependent shapes": bool, + "max dimensions": int, + }, ) DefaultDataTypes = TypedDict( "DefaultDataTypes", { - "real floating": Dtype, - "complex floating": Dtype, - "integral": Dtype, - "indexing": Dtype, + "real floating": DType, + "complex floating": DType, + "integral": DType, + "indexing": DType, }, ) -DataTypes = TypedDict( - "DataTypes", - { - "bool": Dtype, - "float32": Dtype, - "float64": Dtype, - "complex64": Dtype, - "complex128": Dtype, - "int8": Dtype, - "int16": Dtype, - "int32": Dtype, - "int64": Dtype, - "uint8": Dtype, - "uint16": Dtype, - "uint32": Dtype, - "uint64": Dtype, - }, - total=False, -) + +class DataTypes(TypedDict, total=False): + bool: DType + float32: DType + float64: DType + complex64: DType + complex128: DType + int8: DType + int16: DType + int32: DType + int64: DType + uint8: DType + uint16: DType + uint32: DType + uint64: DType diff --git a/array_api_strict/_utility_functions.py b/array_api_strict/_utility_functions.py index f75f36f..fab1025 100644 --- a/array_api_strict/_utility_functions.py +++ b/array_api_strict/_utility_functions.py @@ -1,21 +1,20 @@ from __future__ import annotations -from ._array_object import Array -from ._flags import requires_api_version -from ._dtypes import _numeric_dtypes - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Tuple, Union +from typing import Any import numpy as np +import numpy.typing as npt + +from ._array_object import Array +from ._dtypes import _numeric_dtypes +from ._flags import requires_api_version def all( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: """ @@ -30,7 +29,7 @@ def any( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: """ @@ -40,6 +39,7 @@ def any( """ return Array._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims)), device=x.device) + @requires_api_version('2024.12') def diff( x: Array, @@ -47,8 +47,8 @@ def diff( *, axis: int = -1, n: int = 1, - prepend: Optional[Array] = None, - append: Optional[Array] = None, + prepend: Array | None = None, + append: Array | None = None, ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in diff") @@ -57,7 +57,7 @@ def diff( # currently specified. # NumPy does not support prepend=None or append=None - kwargs = dict(axis=axis, n=n) + kwargs: dict[str, int | npt.NDArray[Any]] = {"axis": axis, "n": n} if prepend is not None: if prepend.device != x.device: raise ValueError(f"Arrays from two different devices ({prepend.device} and {x.device}) can not be combined.") diff --git a/array_api_strict/py.typed b/array_api_strict/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/array_api_strict/tests/test_validation.py b/array_api_strict/tests/test_validation.py index bd76ec6..5552e3a 100644 --- a/array_api_strict/tests/test_validation.py +++ b/array_api_strict/tests/test_validation.py @@ -1,11 +1,9 @@ -from typing import Callable - import pytest import array_api_strict as xp -def p(func: Callable, *args, **kwargs): +def p(func, *args, **kwargs): f_sig = ", ".join( [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs.items()] ) diff --git a/pyproject.toml b/pyproject.toml index b3d2594..cf9e6dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,3 +31,18 @@ Repository = "https://github.com/data-apis/array-api-strict" [tool.setuptools_scm] version_file = "array_api_strict/_version.py" +[tool.mypy] +disallow_incomplete_defs = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +no_implicit_optional = true +show_error_codes = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_unreachable = true +strict_bytes = true +local_partial_types = true + +[[tool.mypy.overrides]] +module = ["*.tests.*"] +disallow_untyped_defs = false From e7c81bc64ef31d8333495c96b24d39300c8eae53 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 1 Apr 2025 09:32:33 +0100 Subject: [PATCH 239/252] Drop Python 3.9 support --- array_api_strict/_array_object.py | 12 ++---------- array_api_strict/_data_type_functions.py | 2 -- array_api_strict/_dtypes.py | 2 -- array_api_strict/_elementwise_functions.py | 2 -- array_api_strict/_fft.py | 2 -- array_api_strict/_flags.py | 12 ++---------- array_api_strict/_helpers.py | 2 -- array_api_strict/_indexing_functions.py | 2 -- array_api_strict/_info.py | 2 -- array_api_strict/_linalg.py | 2 -- array_api_strict/_linear_algebra_functions.py | 3 --- array_api_strict/_manipulation_functions.py | 2 -- array_api_strict/_searching_functions.py | 2 -- array_api_strict/_set_functions.py | 2 -- array_api_strict/_sorting_functions.py | 2 -- array_api_strict/_statistical_functions.py | 2 -- array_api_strict/_utility_functions.py | 2 -- pyproject.toml | 5 ++--- 18 files changed, 6 insertions(+), 54 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 1304d5a..2373c89 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -16,11 +16,10 @@ from __future__ import annotations import operator -import sys from collections.abc import Iterator from enum import IntEnum -from types import ModuleType -from typing import TYPE_CHECKING, Any, Final, Literal, SupportsIndex +from types import EllipsisType, ModuleType +from typing import Any, Final, Literal, SupportsIndex import numpy as np import numpy.typing as npt @@ -43,13 +42,6 @@ from ._flags import get_array_api_strict_flags, set_array_api_strict_flags from ._typing import PyCapsule -if sys.version_info >= (3, 10): - from types import EllipsisType -elif TYPE_CHECKING: - from typing_extensions import EllipsisType -else: - EllipsisType = type(Ellipsis) - class Device: _device: Final[str] diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 7dc918d..16795fc 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from dataclasses import dataclass import numpy as np diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index 513650b..7bed828 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import builtins import warnings from typing import Any, Final diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 6b52a58..b05e0fd 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import numpy as np from ._array_object import Array diff --git a/array_api_strict/_fft.py b/array_api_strict/_fft.py index 2998254..c2c617e 100644 --- a/array_api_strict/_fft.py +++ b/array_api_strict/_fft.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from collections.abc import Sequence from typing import Literal diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 6729a4b..4a099ce 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -11,24 +11,16 @@ library will only support one particular configuration of these flags. """ - -from __future__ import annotations - import functools import os import warnings from collections.abc import Callable from types import TracebackType -from typing import TYPE_CHECKING, Any, Collection, TypeVar, cast +from typing import Any, Collection, ParamSpec, TypeVar, cast import array_api_strict -if TYPE_CHECKING: - # TODO import from typing (requires Python >= 3.10) - from typing_extensions import ParamSpec - - P = ParamSpec("P") - +P = ParamSpec("P") T = TypeVar("T") _CallableT = TypeVar("_CallableT", bound=Callable[..., object]) diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py index 291082e..e8c6767 100644 --- a/array_api_strict/_helpers.py +++ b/array_api_strict/_helpers.py @@ -1,7 +1,5 @@ """Private helper routines.""" -from __future__ import annotations - from ._array_object import Array from ._dtypes import _dtype_categories from ._flags import get_array_api_strict_flags diff --git a/array_api_strict/_indexing_functions.py b/array_api_strict/_indexing_functions.py index ab25fab..ce72ae4 100644 --- a/array_api_strict/_indexing_functions.py +++ b/array_api_strict/_indexing_functions.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import numpy as np from ._array_object import Array diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index 81f88bc..0eb6696 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import numpy as np from . import _dtypes as dt diff --git a/array_api_strict/_linalg.py b/array_api_strict/_linalg.py index 27a2ddf..72d7f0a 100644 --- a/array_api_strict/_linalg.py +++ b/array_api_strict/_linalg.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from collections.abc import Sequence from functools import partial from typing import Literal, NamedTuple diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py index d18214c..384267d 100644 --- a/array_api_strict/_linear_algebra_functions.py +++ b/array_api_strict/_linear_algebra_functions.py @@ -4,9 +4,6 @@ linalg extension is disabled in the flags. """ - -from __future__ import annotations - from collections.abc import Sequence import numpy as np diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index e2fd24c..fe4a608 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import numpy as np from ._array_object import Array diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index b366ed9..c42ccc7 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import Literal import numpy as np diff --git a/array_api_strict/_set_functions.py b/array_api_strict/_set_functions.py index e677a52..8a88047 100644 --- a/array_api_strict/_set_functions.py +++ b/array_api_strict/_set_functions.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import NamedTuple import numpy as np diff --git a/array_api_strict/_sorting_functions.py b/array_api_strict/_sorting_functions.py index e9193f1..08275cf 100644 --- a/array_api_strict/_sorting_functions.py +++ b/array_api_strict/_sorting_functions.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import Literal import numpy as np diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 668cd02..4160f7a 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import Any import numpy as np diff --git a/array_api_strict/_utility_functions.py b/array_api_strict/_utility_functions.py index fab1025..bedb488 100644 --- a/array_api_strict/_utility_functions.py +++ b/array_api_strict/_utility_functions.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import Any import numpy as np diff --git a/pyproject.toml b/pyproject.toml index cf9e6dd..00d09da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "array_api_strict" dynamic = ["version"] -requires-python = ">= 3.9" +requires-python = ">= 3.10" dependencies = ["numpy"] license = {file = "LICENSE"} authors = [ @@ -15,7 +15,6 @@ description = "A strict, minimal implementation of the Python array API standard readme = "README.md" classifiers = [ "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -24,7 +23,7 @@ classifiers = [ "Operating System :: OS Independent", ] -[project-urls] +[project.urls] Homepage = "https://data-apis.org/array-api-strict/" Repository = "https://github.com/data-apis/array-api-strict" From ea5deb1af9a52cb4d1125fbddc1a14d8d86be9ed Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 2 Apr 2025 20:31:07 +1100 Subject: [PATCH 240/252] BUG: fix tuple array indexing reviewed at https://github.com/data-apis/array-api-strict/pull/139 --- array_api_strict/_array_object.py | 20 ++++++++- array_api_strict/tests/test_array_object.py | 46 ++++++++++++++++----- 2 files changed, 53 insertions(+), 13 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 2373c89..cb2dd11 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -716,8 +716,24 @@ def __getitem__( # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index self._validate_index(key, op="getitem") - # Indexing self._array with array_api_strict arrays can be erroneous - np_key = key._array if isinstance(key, Array) else key + if isinstance(key, Array): + key = (key,) + np_key = key + devices = {self.device} + if isinstance(key, tuple): + devices.update( + [subkey.device for subkey in key if hasattr(subkey, "device")] + ) + if len(devices) > 1: + raise ValueError( + "Array indexing is only allowed when array to be indexed and all " + "indexing arrays are on the same device." + ) + # Indexing self._array with array_api_strict arrays can be erroneous + # e.g., when using non-default device + np_key = tuple( + subkey._array if isinstance(subkey, Array) else subkey for subkey in key + ) res = self._array.__getitem__(np_key) return self._new(res, device=self.device) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index e24a40f..51f4f31 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from .. import ones, arange, reshape, asarray, result_type, all, equal +from .. import ones, arange, reshape, asarray, result_type, all, equal, stack from .._array_object import Array, CPU_DEVICE, Device from .._dtypes import ( _all_dtypes, @@ -101,33 +101,40 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[idx]) -def test_indexing_arrays(): +@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"]) +def test_indexing_arrays(device): # indexing with 1D integer arrays and mixes of integers and 1D integer are allowed + device = None if device is None else Device(device) # 1D array - a = arange(5) - idx = asarray([1, 0, 1, 2, -1]) + a = arange(5, device=device) + idx = asarray([1, 0, 1, 2, -1], device=device) a_idx = a[idx] - a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])]) + a_idx_loop = stack([a[idx[i]] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) + assert a_idx.shape == idx.shape + assert a.device == idx.device == a_idx.device # setitem with arrays is not allowed with assert_raises(IndexError): a[idx] = 42 # mixed array and integer indexing - a = reshape(arange(3*4), (3, 4)) - idx = asarray([1, 0, 1, 2, -1]) + a = reshape(arange(3*4, device=device), (3, 4)) + idx = asarray([1, 0, 1, 2, -1], device=device) a_idx = a[idx, 1] - - a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])]) + a_idx_loop = stack([a[idx[i], 1] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) + assert a_idx.shape == idx.shape + assert a.device == idx.device == a_idx.device # index with two arrays a_idx = a[idx, idx] - a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])]) + a_idx_loop = stack([a[idx[i], idx[i]] for i in range(idx.shape[0])]) assert all(a_idx == a_idx_loop) + assert a_idx.shape == a_idx.shape + assert a.device == idx.device == a_idx.device # setitem with arrays is not allowed with assert_raises(IndexError): @@ -135,7 +142,24 @@ def test_indexing_arrays(): # smoke test indexing with ndim > 1 arrays idx = idx[..., None] - a[idx, idx] + a_idx = a[idx, idx] + assert a.device == idx.device == a_idx.device + + +def test_indexing_arrays_different_devices(): + # Ensure indexing via array on different device errors + device1 = Device("CPU_DEVICE") + device2 = Device("device1") + + a = arange(5, device=device1) + idx1 = asarray([1, 0, 1, 2, -1], device=device2) + idx2 = asarray([1, 0, 1, 2, -1], device=device1) + + with pytest.raises(ValueError, match="Array indexing is only allowed when"): + a[idx1] + + with pytest.raises(ValueError, match="Array indexing is only allowed when"): + a[idx1, idx2] def test_promoted_scalar_inherits_device(): From 72eabc47cbbd20c4436e05f6c545e5252eaafeaf Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 2 Apr 2025 11:37:39 +0100 Subject: [PATCH 241/252] MAINT: `result_type` cosmetic refactor --- array_api_strict/_data_type_functions.py | 34 +++++++++++------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 16795fc..c3c8462 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -204,35 +204,31 @@ def result_type( # required by the spec rather than using np.result_type. NumPy implements # too many extra type promotions like int64 + uint64 -> float64, and does # value-based casting on scalar arrays. - A = [] + dtypes = [] scalars = [] for a in arrays_and_dtypes: - if isinstance(a, Array): - a = a.dtype + if isinstance(a, DType): + dtypes.append(a) + elif isinstance(a, Array): + dtypes.append(a.dtype) elif isinstance(a, (bool, int, float, complex)): scalars.append(a) - elif isinstance(a, np.ndarray) or a not in _all_dtypes: - raise TypeError("result_type() inputs must be array_api arrays or dtypes") - A.append(a) - - # remove python scalars - B = [a for a in A if not isinstance(a, (bool, int, float, complex))] + else: + raise TypeError( + "result_type() inputs must be Array API arrays, dtypes, or scalars" + ) - if len(B) == 0: + if not dtypes: raise ValueError("at least one array or dtype is required") - elif len(B) == 1: - result = B[0] - else: - t = B[0] - for t2 in B[1:]: - t = _result_type(t, t2) - result = t + result = dtypes[0] + for t2 in dtypes[1:]: + result = _result_type(result, t2) - if len(scalars) == 0: + if not scalars: return result if get_array_api_strict_flags()['api_version'] <= '2023.12': - raise TypeError("result_type() inputs must be array_api arrays or dtypes") + raise TypeError("result_type() inputs must be Array API arrays or dtypes") # promote python scalars given the result_type for all arrays/dtypes from ._creation_functions import empty From f5778f6d1e75d5146ecbaa2082450194fd4073d0 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 2 Apr 2025 11:44:14 +0100 Subject: [PATCH 242/252] MAINT: finfo() / iinfo() input/output review --- array_api_strict/_data_type_functions.py | 26 +++++++++-- array_api_strict/_dtypes.py | 2 +- .../tests/test_data_type_functions.py | 46 ++++++++++++++++--- 3 files changed, 63 insertions(+), 11 deletions(-) diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index c3c8462..e318724 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -100,7 +100,9 @@ def can_cast(from_: DType | Array, to: DType, /) -> bool: # These are internal objects for the return types of finfo and iinfo, since # the NumPy versions contain extra data that isn't part of the spec. -@dataclass +# There should be no expectation for them to be comparable, hashable, or writeable. + +@dataclass(frozen=True, eq=False) class finfo_object: bits: int # Note: The types of the float data here are float, whereas in NumPy they @@ -111,14 +113,18 @@ class finfo_object: smallest_normal: float dtype: DType + __hash__ = NotImplemented + -@dataclass +@dataclass(frozen=True, eq=False) class iinfo_object: bits: int max: int min: int dtype: DType + __hash__ = NotImplemented + def finfo(type: DType | Array, /) -> finfo_object: """ @@ -126,7 +132,13 @@ def finfo(type: DType | Array, /) -> finfo_object: See its docstring for more information. """ - np_type = type._array if isinstance(type, Array) else type._np_dtype + if isinstance(type, Array): + np_type = type._dtype._np_dtype + elif isinstance(type, DType): + np_type = type._np_dtype + else: + raise TypeError(f"'type' must be a dtype or array, not {type!r}") + fi = np.finfo(np_type) # Note: The types of the float data here are float, whereas in NumPy they # are scalars of the corresponding float dtype. @@ -146,7 +158,13 @@ def iinfo(type: DType | Array, /) -> iinfo_object: See its docstring for more information. """ - np_type = type._array if isinstance(type, Array) else type._np_dtype + if isinstance(type, Array): + np_type = type._dtype._np_dtype + elif isinstance(type, DType): + np_type = type._np_dtype + else: + raise TypeError(f"'type' must be a dtype or array, not {type!r}") + ii = np.iinfo(np_type) return iinfo_object(ii.bits, ii.max, ii.min, DType(ii.dtype)) diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index 7bed828..564db5a 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ -35,7 +35,7 @@ def __eq__(self, other: object) -> builtins.bool: stacklevel=2, ) if not isinstance(other, DType): - return NotImplemented + return False return self._np_dtype == other._np_dtype def __hash__(self) -> int: diff --git a/array_api_strict/tests/test_data_type_functions.py b/array_api_strict/tests/test_data_type_functions.py index 919c0b4..7f24920 100644 --- a/array_api_strict/tests/test_data_type_functions.py +++ b/array_api_strict/tests/test_data_type_functions.py @@ -1,15 +1,12 @@ import warnings +import numpy as np import pytest - from numpy.testing import assert_raises -import numpy as np from .._creation_functions import asarray -from .._data_type_functions import astype, can_cast, isdtype, result_type -from .._dtypes import ( - bool, int8, int16, uint8, float64, int64 -) +from .._data_type_functions import astype, can_cast, finfo, iinfo, isdtype, result_type +from .._dtypes import bool, float64, int8, int16, int64, uint8 from .._flags import set_array_api_strict_flags @@ -88,3 +85,40 @@ def test_result_type_py_scalars(api_version): with pytest.raises(TypeError): result_type(int64, True) + + +def test_finfo_iinfo_dtypelike(): + """np.finfo() and np.iinfo() accept any DTypeLike. + Array API only accepts Array | DType. + """ + match = "must be a dtype or array" + with pytest.raises(TypeError, match=match): + finfo("float64") + with pytest.raises(TypeError, match=match): + finfo(float) + with pytest.raises(TypeError, match=match): + iinfo("int8") + with pytest.raises(TypeError, match=match): + iinfo(int) + + +def test_finfo_iinfo_wrap_output(): + """Test that the finfo(...).dtype and iinfo(...).dtype + are array-api-strict.DType objects; not numpy.dtype. + """ + # Note: array_api_strict.DType objects are not singletons + assert finfo(float64).dtype == float64 + assert iinfo(int8).dtype == int8 + + +@pytest.mark.parametrize("func,arg", [(finfo, float64), (iinfo, int8)]) +def test_finfo_iinfo_output_assumptions(func, arg): + """There should be no expectation for the output of finfo()/iinfo() + to be comparable, hashable, or writeable. + """ + obj = func(arg) + assert obj != func(arg) # Defaut behaviour for custom classes + with pytest.raises(TypeError): + hash(obj) + with pytest.raises(Exception, match="cannot assign"): + obj.min = 0 From 25cc3d7ff0069b222d228d380e6d95bbf9a5dbcf Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 3 Apr 2025 21:18:43 +1100 Subject: [PATCH 243/252] Fix indexing with integers (#146) reviewed at https://github.com/data-apis/array-api-strict/pull/146 --- array_api_strict/_array_object.py | 2 +- array_api_strict/tests/test_array_object.py | 30 +++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index cb2dd11..483952e 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -722,7 +722,7 @@ def __getitem__( devices = {self.device} if isinstance(key, tuple): devices.update( - [subkey.device for subkey in key if hasattr(subkey, "device")] + [subkey.device for subkey in key if isinstance(subkey, Array)] ) if len(devices) > 1: raise ValueError( diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 51f4f31..c7330d8 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -100,6 +100,36 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[:]) assert_raises(IndexError, lambda: a[idx]) +class DummyIndex: + def __init__(self, x): + self.x = x + def __index__(self): + return self.x + + +@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"]) +@pytest.mark.parametrize( + "integer_index", + [ + 0, + np.int8(0), + np.uint8(0), + np.int16(0), + np.uint16(0), + np.int32(0), + np.uint32(0), + np.int64(0), + np.uint64(0), + DummyIndex(0), + ], +) +def test_indexing_ints(integer_index, device): + # Ensure indexing with different integer types works on all Devices. + device = None if device is None else Device(device) + + a = arange(5, device=device) + assert a[(integer_index,)] == a[integer_index] == a[0] + @pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"]) def test_indexing_arrays(device): From 6029e2fbc252a81f85bb218806b765fd32d72f9f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 9 Apr 2025 13:11:12 +0200 Subject: [PATCH 244/252] BUG: do not allow asarray of nested sequences of arrays --- array_api_strict/_creation_functions.py | 4 +++ array_api_strict/tests/test_array_object.py | 4 +-- .../tests/test_creation_functions.py | 26 ++++++++++--------- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 3b80b8a..db3897c 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -121,6 +121,10 @@ def asarray( if isinstance(obj, Array): return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype), device=device) + elif isinstance(obj, list | tuple): + if any(isinstance(x, Array) for x in obj): + raise TypeError("Nested Arrays are not allowed. Use `stack` instead.") + if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)): # Give a better error message in this case. NumPy would convert this # to an object array. TODO: This won't handle large integers in lists. diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index c7330d8..dbab1af 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -482,12 +482,12 @@ def test_array_conversion(): # __array__, which is only used in asarray when converting lists of # arrays. a = ones((2, 3)) - asarray([a]) + np.asarray(a) for device in ("device1", "device2"): a = ones((2, 3), device=array_api_strict.Device(device)) with pytest.raises(RuntimeError, match="Can not convert array"): - asarray([a]) + np.asarray(a) def test__array__(): # __array__ should work for now diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index fc4e3cb..573fc7f 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -22,7 +22,7 @@ zeros, zeros_like, ) -from .._dtypes import int16, float32, float64 +from .._dtypes import float32, float64 from .._array_object import Array, CPU_DEVICE, Device from .._flags import set_array_api_strict_flags @@ -97,18 +97,20 @@ def test_asarray_copy(): a[0] = 0 assert all(b[0] == 0) + def test_asarray_list_of_lists(): - a = asarray(1, dtype=int16) - b = asarray([1], dtype=int16) - res = asarray([a, a]) - assert res.shape == (2,) - assert res.dtype == int16 - assert all(res == asarray([1, 1])) - - res = asarray([b, b]) - assert res.shape == (2, 1) - assert res.dtype == int16 - assert all(res == asarray([[1], [1]])) + lst = [[1, 2, 3], [4, 5, 6]] + res = asarray(lst) + assert res.shape == (2, 3) + + +def test_asarray_nested_arrays(): + # do not allow arrays in nested sequences + with pytest.raises(TypeError): + asarray([[1, 2, 3], asarray([4, 5, 6])]) + + with pytest.raises(TypeError): + asarray([1, asarray(1)]) def test_asarray_device_inference(): From b0e4df9114ecbdd5c3975401a630727d66aec4b7 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 2 Apr 2025 14:09:33 +0100 Subject: [PATCH 245/252] TST: test binops vs. np.generics --- array_api_strict/_array_object.py | 86 +++++++------- array_api_strict/tests/test_array_object.py | 125 ++++++++++++++++---- 2 files changed, 142 insertions(+), 69 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 483952e..6f2c506 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -233,15 +233,15 @@ def _check_allowed_dtypes( return other - def _check_device(self, other: Array | bool | int | float | complex) -> None: - """Check that other is on a device compatible with the current array""" - if isinstance(other, (bool, int, float, complex)): - return - elif isinstance(other, Array): + def _check_type_device(self, other: Array | bool | int | float | complex) -> None: + """Check that other is either a Python scalar or an array on a device + compatible with the current array. + """ + if isinstance(other, Array): if self.device != other.device: raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") - else: - raise TypeError(f"Expected Array | python scalar; got {type(other)}") + elif not isinstance(other, bool | int | float | complex): + raise TypeError(f"Expected Array or Python scalar; got {type(other)}") # Helper function to match the type promotion rules in the spec def _promote_scalar(self, scalar: bool | int | float | complex) -> Array: @@ -542,7 +542,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __add__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__add__") if other is NotImplemented: return other @@ -554,7 +554,7 @@ def __and__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __and__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__and__") if other is NotImplemented: return other @@ -651,7 +651,7 @@ def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # ty """ Performs the operation __eq__. """ - self._check_device(other) + self._check_type_device(other) # Even though "all" dtypes are allowed, we still require them to be # promotable with each other. other = self._check_allowed_dtypes(other, "all", "__eq__") @@ -677,7 +677,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __floordiv__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__") if other is NotImplemented: return other @@ -689,7 +689,7 @@ def __ge__(self, other: Array | int | float, /) -> Array: """ Performs the operation __ge__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__ge__") if other is NotImplemented: return other @@ -741,7 +741,7 @@ def __gt__(self, other: Array | int | float, /) -> Array: """ Performs the operation __gt__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__gt__") if other is NotImplemented: return other @@ -796,7 +796,7 @@ def __le__(self, other: Array | int | float, /) -> Array: """ Performs the operation __le__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__le__") if other is NotImplemented: return other @@ -808,7 +808,7 @@ def __lshift__(self, other: Array | int, /) -> Array: """ Performs the operation __lshift__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer", "__lshift__") if other is NotImplemented: return other @@ -820,7 +820,7 @@ def __lt__(self, other: Array | int | float, /) -> Array: """ Performs the operation __lt__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__lt__") if other is NotImplemented: return other @@ -832,7 +832,7 @@ def __matmul__(self, other: Array, /) -> Array: """ Performs the operation __matmul__. """ - self._check_device(other) + self._check_type_device(other) # matmul is not defined for scalars, but without this, we may get # the wrong error message from asarray. other = self._check_allowed_dtypes(other, "numeric", "__matmul__") @@ -845,7 +845,7 @@ def __mod__(self, other: Array | int | float, /) -> Array: """ Performs the operation __mod__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__mod__") if other is NotImplemented: return other @@ -857,7 +857,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __mul__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__mul__") if other is NotImplemented: return other @@ -869,7 +869,7 @@ def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # ty """ Performs the operation __ne__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "all", "__ne__") if other is NotImplemented: return other @@ -890,7 +890,7 @@ def __or__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __or__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__or__") if other is NotImplemented: return other @@ -913,7 +913,7 @@ def __pow__(self, other: Array | int | float | complex, /) -> Array: """ from ._elementwise_functions import pow # type: ignore[attr-defined] - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__pow__") if other is NotImplemented: return other @@ -925,7 +925,7 @@ def __rshift__(self, other: Array | int, /) -> Array: """ Performs the operation __rshift__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer", "__rshift__") if other is NotImplemented: return other @@ -961,7 +961,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __sub__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__sub__") if other is NotImplemented: return other @@ -975,7 +975,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __truediv__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "floating-point", "__truediv__") if other is NotImplemented: return other @@ -987,7 +987,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __xor__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__") if other is NotImplemented: return other @@ -999,7 +999,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __iadd__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__iadd__") if other is NotImplemented: return other @@ -1010,7 +1010,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array: """ Performs the operation __radd__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "numeric", "__radd__") if other is NotImplemented: return other @@ -1022,7 +1022,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __iand__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__") if other is NotImplemented: return other @@ -1033,7 +1033,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __rand__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__") if other is NotImplemented: return other @@ -1045,7 +1045,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __ifloordiv__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__") if other is NotImplemented: return other @@ -1056,7 +1056,7 @@ def __rfloordiv__(self, other: Array | int | float, /) -> Array: """ Performs the operation __rfloordiv__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__") if other is NotImplemented: return other @@ -1068,7 +1068,7 @@ def __ilshift__(self, other: Array | int, /) -> Array: """ Performs the operation __ilshift__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer", "__ilshift__") if other is NotImplemented: return other @@ -1079,7 +1079,7 @@ def __rlshift__(self, other: Array | int, /) -> Array: """ Performs the operation __rlshift__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer", "__rlshift__") if other is NotImplemented: return other @@ -1096,7 +1096,7 @@ def __imatmul__(self, other: Array, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__imatmul__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) res = self._array.__imatmul__(other._array) return self.__class__._new(res, device=self.device) @@ -1109,7 +1109,7 @@ def __rmatmul__(self, other: Array, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) res = self._array.__rmatmul__(other._array) return self.__class__._new(res, device=self.device) @@ -1130,7 +1130,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array: other = self._check_allowed_dtypes(other, "real numeric", "__rmod__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rmod__(other._array) return self.__class__._new(res, device=self.device) @@ -1152,7 +1152,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rmul__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rmul__(other._array) return self.__class__._new(res, device=self.device) @@ -1171,7 +1171,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array: """ Performs the operation __ror__. """ - self._check_device(other) + self._check_type_device(other) other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__") if other is NotImplemented: return other @@ -1219,7 +1219,7 @@ def __rrshift__(self, other: Array | int, /) -> Array: other = self._check_allowed_dtypes(other, "integer", "__rrshift__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rrshift__(other._array) return self.__class__._new(res, device=self.device) @@ -1241,7 +1241,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array: other = self._check_allowed_dtypes(other, "numeric", "__rsub__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rsub__(other._array) return self.__class__._new(res, device=self.device) @@ -1263,7 +1263,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array: other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rtruediv__(other._array) return self.__class__._new(res, device=self.device) @@ -1285,7 +1285,7 @@ def __rxor__(self, other: Array | bool | int, /) -> Array: other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__") if other is NotImplemented: return other - self._check_device(other) + self._check_type_device(other) self, other = self._normalize_two_args(self, other) res = self._array.__rxor__(other._array) return self.__class__._new(res, device=self.device) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index dbab1af..91f3838 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -255,30 +255,37 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT): func(s) return False +binary_op_dtypes = { + "__add__": "numeric", + "__and__": "integer or boolean", + "__eq__": "all", + "__floordiv__": "real numeric", + "__ge__": "real numeric", + "__gt__": "real numeric", + "__le__": "real numeric", + "__lshift__": "integer", + "__lt__": "real numeric", + "__mod__": "real numeric", + "__mul__": "numeric", + "__ne__": "all", + "__or__": "integer or boolean", + "__pow__": "numeric", + "__rshift__": "integer", + "__sub__": "numeric", + "__truediv__": "floating-point", + "__xor__": "integer or boolean", +} +unary_op_dtypes = { + "__abs__": "numeric", + "__invert__": "integer or boolean", + "__neg__": "numeric", + "__pos__": "numeric", +} def test_operators(): # For every operator, we test that it works for the required type # combinations and raises TypeError otherwise - binary_op_dtypes = { - "__add__": "numeric", - "__and__": "integer or boolean", - "__eq__": "all", - "__floordiv__": "real numeric", - "__ge__": "real numeric", - "__gt__": "real numeric", - "__le__": "real numeric", - "__lshift__": "integer", - "__lt__": "real numeric", - "__mod__": "real numeric", - "__mul__": "numeric", - "__ne__": "all", - "__or__": "integer or boolean", - "__pow__": "numeric", - "__rshift__": "integer", - "__sub__": "numeric", - "__truediv__": "floating-point", - "__xor__": "integer or boolean", - } + # Recompute each time because of in-place ops def _array_vals(): for d in _integer_dtypes: @@ -337,12 +344,6 @@ def _array_vals(): else: assert_raises(TypeError, lambda: getattr(x, _op)(y)) - unary_op_dtypes = { - "__abs__": "numeric", - "__invert__": "integer or boolean", - "__neg__": "numeric", - "__pos__": "numeric", - } for op, dtypes in unary_op_dtypes.items(): for a in _array_vals(): if ( @@ -410,6 +411,78 @@ def _matmul_array_vals(): x.__imatmul__(y) +@pytest.mark.parametrize( + "op", + [ + op for op, dtypes in binary_op_dtypes.items() + if dtypes not in ("real numeric", "floating-point") + ], +) +def test_binary_operators_vs_numpy_int(op): + """np.int64 is not a subclass of int and must be disallowed""" + a = asarray(1) + i64 = np.int64(1) + with pytest.raises(TypeError, match="Expected Array or Python scalar"): + getattr(a, op)(i64) + + +@pytest.mark.parametrize( + "op", + [ + op for op, dtypes in binary_op_dtypes.items() + if dtypes not in ("integer", "integer or boolean") + ], +) +def test_binary_operators_vs_numpy_float(op): + """ + np.float64 is a subclass of float and must be allowed. + np.float32 is not and must be rejected. + """ + a = asarray(1.) + f64 = np.float64(1.) + f32 = np.float32(1.) + func = getattr(a, op) + for op in binary_op_dtypes: + assert isinstance(func(f64), Array) + with pytest.raises(TypeError, match="Expected Array or Python scalar"): + func(f32) + + +@pytest.mark.parametrize( + "op", + [ + op for op, dtypes in binary_op_dtypes.items() + if dtypes not in ("integer", "integer or boolean", "real numeric") + ], +) +def test_binary_operators_vs_numpy_complex(op): + """ + np.complex128 is a subclass of complex and must be allowed. + np.complex64 is not and must be rejected. + """ + a = asarray(1.) + c64 = np.complex64(1.) + c128 = np.complex128(1.) + func = getattr(a, op) + for op in binary_op_dtypes: + assert isinstance(func(c128), Array) + with pytest.raises(TypeError, match="Expected Array or Python scalar"): + func(c64) + + +@pytest.mark.parametrize("op,dtypes", binary_op_dtypes.items()) +def test_binary_operators_device_mismatch(op, dtypes): + if dtypes in ("real numeric", "floating-point"): + dtype = float64 + else: + dtype = int64 + + a = asarray(1, dtype=dtype, device=CPU_DEVICE) + b = asarray(1, dtype=dtype, device=Device("device1")) + with pytest.raises(ValueError, match="different devices"): + getattr(a, op)(b) + + def test_python_scalar_construtors(): b = asarray(False) i = asarray(0) From e7fcd348a454e95b354de3d36573668296fac3d2 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 10:31:01 +0100 Subject: [PATCH 246/252] Disallow float64 and complex128 --- array_api_strict/_array_object.py | 3 +- array_api_strict/tests/test_array_object.py | 97 ++++++++------------- 2 files changed, 38 insertions(+), 62 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 6f2c506..579da90 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -240,7 +240,8 @@ def _check_type_device(self, other: Array | bool | int | float | complex) -> Non if isinstance(other, Array): if self.device != other.device: raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.") - elif not isinstance(other, bool | int | float | complex): + # Disallow subclasses of Python scalars, such as np.float64 and np.complex128 + elif type(other) not in (bool, int, float, complex): raise TypeError(f"Expected Array or Python scalar; got {type(other)}") # Helper function to match the type promotion rules in the spec diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 91f3838..e950be5 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -411,72 +411,47 @@ def _matmul_array_vals(): x.__imatmul__(y) -@pytest.mark.parametrize( - "op", - [ - op for op, dtypes in binary_op_dtypes.items() - if dtypes not in ("real numeric", "floating-point") - ], -) -def test_binary_operators_vs_numpy_int(op): - """np.int64 is not a subclass of int and must be disallowed""" - a = asarray(1) - i64 = np.int64(1) - with pytest.raises(TypeError, match="Expected Array or Python scalar"): - getattr(a, op)(i64) - - -@pytest.mark.parametrize( - "op", - [ - op for op, dtypes in binary_op_dtypes.items() - if dtypes not in ("integer", "integer or boolean") - ], -) -def test_binary_operators_vs_numpy_float(op): - """ - np.float64 is a subclass of float and must be allowed. - np.float32 is not and must be rejected. - """ - a = asarray(1.) - f64 = np.float64(1.) - f32 = np.float32(1.) - func = getattr(a, op) - for op in binary_op_dtypes: - assert isinstance(func(f64), Array) - with pytest.raises(TypeError, match="Expected Array or Python scalar"): - func(f32) - - -@pytest.mark.parametrize( - "op", - [ - op for op, dtypes in binary_op_dtypes.items() - if dtypes not in ("integer", "integer or boolean", "real numeric") - ], -) -def test_binary_operators_vs_numpy_complex(op): - """ - np.complex128 is a subclass of complex and must be allowed. - np.complex64 is not and must be rejected. +@pytest.mark.parametrize("op,dtypes", binary_op_dtypes.items()) +def test_binary_operators_vs_numpy_generics(op, dtypes): + """Test that np.bool_, np.int64, np.float32, np.float64, np.complex64, np.complex128 + are disallowed in binary operators. + np.float64 and np.complex128 are subclasses of float and complex, so they need + special treatment in order to be rejected. """ - a = asarray(1.) - c64 = np.complex64(1.) - c128 = np.complex128(1.) - func = getattr(a, op) - for op in binary_op_dtypes: - assert isinstance(func(c128), Array) - with pytest.raises(TypeError, match="Expected Array or Python scalar"): - func(c64) + match = "Expected Array or Python scalar" + + if dtypes not in ("numeric", "integer", "real numeric", "floating-point"): + a = asarray(True) + func = getattr(a, op) + with pytest.raises(TypeError, match=match): + func(np.bool_(True)) + + if dtypes != "floating-point": + a = asarray(1) + func = getattr(a, op) + with pytest.raises(TypeError, match=match): + func(np.int64(1)) + + if dtypes not in ("integer", "integer or boolean"): + a = asarray(1.,) + func = getattr(a, op) + with pytest.raises(TypeError, match=match): + func(np.float32(1.)) + with pytest.raises(TypeError, match=match): + func(np.float64(1.)) + + if dtypes not in ("integer", "integer or boolean", "real numeric"): + a = asarray(1.,) + func = getattr(a, op) + with pytest.raises(TypeError, match=match): + func(np.complex64(1.)) + with pytest.raises(TypeError, match=match): + func(np.complex128(1.)) @pytest.mark.parametrize("op,dtypes", binary_op_dtypes.items()) def test_binary_operators_device_mismatch(op, dtypes): - if dtypes in ("real numeric", "floating-point"): - dtype = float64 - else: - dtype = int64 - + dtype = float64 if dtypes == "floating-point" else int64 a = asarray(1, dtype=dtype, device=CPU_DEVICE) b = asarray(1, dtype=dtype, device=Device("device1")) with pytest.raises(ValueError, match="different devices"): From d213304dc8356e3b7dcbcf7a959edb2b4863c81d Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 8 May 2025 20:05:03 +0200 Subject: [PATCH 247/252] BUG: roll does not accept array shifts --- array_api_strict/_manipulation_functions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index fe4a608..7c4adda 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -138,6 +138,10 @@ def roll( See its docstring for more information. """ + if not isinstance(shift, int | tuple): + raise ValueError( + f"`shift` can only be an int or a tuple, got {type(shift)=} instead." + ) return Array._new(np.roll(x._array, shift, axis=axis), device=x.device) From bf0c944f229d6aee483e15cecc47168b904c675c Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Sat, 10 May 2025 23:45:41 +0200 Subject: [PATCH 248/252] TYP: Compact Python scalar types (#149) * TYP: Compact Python scalar types * Update array_api_strict/_array_object.py --- array_api_strict/_array_object.py | 80 +++++++++++----------- array_api_strict/_creation_functions.py | 22 +++--- array_api_strict/_data_type_functions.py | 4 +- array_api_strict/_elementwise_functions.py | 14 ++-- array_api_strict/_helpers.py | 4 +- array_api_strict/_linalg.py | 2 +- array_api_strict/_searching_functions.py | 7 +- array_api_strict/_statistical_functions.py | 4 +- 8 files changed, 62 insertions(+), 75 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 579da90..7242055 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -191,7 +191,7 @@ def __array__( # NumPy behavior def _check_allowed_dtypes( - self, other: Array | bool | int | float | complex, dtype_category: str, op: str + self, other: Array | complex, dtype_category: str, op: str ) -> Array: """ Helper function for operators to only allow specific input dtypes @@ -233,7 +233,7 @@ def _check_allowed_dtypes( return other - def _check_type_device(self, other: Array | bool | int | float | complex) -> None: + def _check_type_device(self, other: Array | complex) -> None: """Check that other is either a Python scalar or an array on a device compatible with the current array. """ @@ -245,7 +245,7 @@ def _check_type_device(self, other: Array | bool | int | float | complex) -> Non raise TypeError(f"Expected Array or Python scalar; got {type(other)}") # Helper function to match the type promotion rules in the spec - def _promote_scalar(self, scalar: bool | int | float | complex) -> Array: + def _promote_scalar(self, scalar: complex) -> Array: """ Returns a promoted version of a Python scalar appropriate for use with operations on self. @@ -539,7 +539,7 @@ def __abs__(self) -> Array: res = self._array.__abs__() return self.__class__._new(res, device=self.device) - def __add__(self, other: Array | int | float | complex, /) -> Array: + def __add__(self, other: Array | complex, /) -> Array: """ Performs the operation __add__. """ @@ -551,7 +551,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array: res = self._array.__add__(other._array) return self.__class__._new(res, device=self.device) - def __and__(self, other: Array | bool | int, /) -> Array: + def __and__(self, other: Array | int, /) -> Array: """ Performs the operation __and__. """ @@ -648,7 +648,7 @@ def __dlpack_device__(self) -> tuple[IntEnum, int]: # Note: device support is required for this return self._array.__dlpack_device__() - def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override] + def __eq__(self, other: Array | complex, /) -> Array: # type: ignore[override] """ Performs the operation __eq__. """ @@ -674,7 +674,7 @@ def __float__(self) -> float: res = self._array.__float__() return res - def __floordiv__(self, other: Array | int | float, /) -> Array: + def __floordiv__(self, other: Array | float, /) -> Array: """ Performs the operation __floordiv__. """ @@ -686,7 +686,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array: res = self._array.__floordiv__(other._array) return self.__class__._new(res, device=self.device) - def __ge__(self, other: Array | int | float, /) -> Array: + def __ge__(self, other: Array | float, /) -> Array: """ Performs the operation __ge__. """ @@ -738,7 +738,7 @@ def __getitem__( res = self._array.__getitem__(np_key) return self._new(res, device=self.device) - def __gt__(self, other: Array | int | float, /) -> Array: + def __gt__(self, other: Array | float, /) -> Array: """ Performs the operation __gt__. """ @@ -793,7 +793,7 @@ def __iter__(self) -> Iterator[Array]: # implemented, which implies iteration on 1-D arrays. return (Array._new(i, device=self.device) for i in self._array) - def __le__(self, other: Array | int | float, /) -> Array: + def __le__(self, other: Array | float, /) -> Array: """ Performs the operation __le__. """ @@ -817,7 +817,7 @@ def __lshift__(self, other: Array | int, /) -> Array: res = self._array.__lshift__(other._array) return self.__class__._new(res, device=self.device) - def __lt__(self, other: Array | int | float, /) -> Array: + def __lt__(self, other: Array | float, /) -> Array: """ Performs the operation __lt__. """ @@ -842,7 +842,7 @@ def __matmul__(self, other: Array, /) -> Array: res = self._array.__matmul__(other._array) return self.__class__._new(res, device=self.device) - def __mod__(self, other: Array | int | float, /) -> Array: + def __mod__(self, other: Array | float, /) -> Array: """ Performs the operation __mod__. """ @@ -854,7 +854,7 @@ def __mod__(self, other: Array | int | float, /) -> Array: res = self._array.__mod__(other._array) return self.__class__._new(res, device=self.device) - def __mul__(self, other: Array | int | float | complex, /) -> Array: + def __mul__(self, other: Array | complex, /) -> Array: """ Performs the operation __mul__. """ @@ -866,7 +866,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array: res = self._array.__mul__(other._array) return self.__class__._new(res, device=self.device) - def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override] + def __ne__(self, other: Array | complex, /) -> Array: # type: ignore[override] """ Performs the operation __ne__. """ @@ -887,7 +887,7 @@ def __neg__(self) -> Array: res = self._array.__neg__() return self.__class__._new(res, device=self.device) - def __or__(self, other: Array | bool | int, /) -> Array: + def __or__(self, other: Array | int, /) -> Array: """ Performs the operation __or__. """ @@ -908,7 +908,7 @@ def __pos__(self) -> Array: res = self._array.__pos__() return self.__class__._new(res, device=self.device) - def __pow__(self, other: Array | int | float | complex, /) -> Array: + def __pow__(self, other: Array | complex, /) -> Array: """ Performs the operation __pow__. """ @@ -945,7 +945,7 @@ def __setitem__( | Array | tuple[int | slice | EllipsisType, ...] ), - value: Array | bool | int | float | complex, + value: Array | complex, /, ) -> None: """ @@ -958,7 +958,7 @@ def __setitem__( np_key = key._array if isinstance(key, Array) else key self._array.__setitem__(np_key, asarray(value)._array) - def __sub__(self, other: Array | int | float | complex, /) -> Array: + def __sub__(self, other: Array | complex, /) -> Array: """ Performs the operation __sub__. """ @@ -972,7 +972,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array: # PEP 484 requires int to be a subtype of float, but __truediv__ should # not accept int. - def __truediv__(self, other: Array | int | float | complex, /) -> Array: + def __truediv__(self, other: Array | complex, /) -> Array: """ Performs the operation __truediv__. """ @@ -984,7 +984,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array: res = self._array.__truediv__(other._array) return self.__class__._new(res, device=self.device) - def __xor__(self, other: Array | bool | int, /) -> Array: + def __xor__(self, other: Array | int, /) -> Array: """ Performs the operation __xor__. """ @@ -996,7 +996,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array: res = self._array.__xor__(other._array) return self.__class__._new(res, device=self.device) - def __iadd__(self, other: Array | int | float | complex, /) -> Array: + def __iadd__(self, other: Array | complex, /) -> Array: """ Performs the operation __iadd__. """ @@ -1007,7 +1007,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array: self._array.__iadd__(other._array) return self - def __radd__(self, other: Array | int | float | complex, /) -> Array: + def __radd__(self, other: Array | complex, /) -> Array: """ Performs the operation __radd__. """ @@ -1019,7 +1019,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array: res = self._array.__radd__(other._array) return self.__class__._new(res, device=self.device) - def __iand__(self, other: Array | bool | int, /) -> Array: + def __iand__(self, other: Array | int, /) -> Array: """ Performs the operation __iand__. """ @@ -1030,7 +1030,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array: self._array.__iand__(other._array) return self - def __rand__(self, other: Array | bool | int, /) -> Array: + def __rand__(self, other: Array | int, /) -> Array: """ Performs the operation __rand__. """ @@ -1042,7 +1042,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array: res = self._array.__rand__(other._array) return self.__class__._new(res, device=self.device) - def __ifloordiv__(self, other: Array | int | float, /) -> Array: + def __ifloordiv__(self, other: Array | float, /) -> Array: """ Performs the operation __ifloordiv__. """ @@ -1053,7 +1053,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array: self._array.__ifloordiv__(other._array) return self - def __rfloordiv__(self, other: Array | int | float, /) -> Array: + def __rfloordiv__(self, other: Array | float, /) -> Array: """ Performs the operation __rfloordiv__. """ @@ -1114,7 +1114,7 @@ def __rmatmul__(self, other: Array, /) -> Array: res = self._array.__rmatmul__(other._array) return self.__class__._new(res, device=self.device) - def __imod__(self, other: Array | int | float, /) -> Array: + def __imod__(self, other: Array | float, /) -> Array: """ Performs the operation __imod__. """ @@ -1124,7 +1124,7 @@ def __imod__(self, other: Array | int | float, /) -> Array: self._array.__imod__(other._array) return self - def __rmod__(self, other: Array | int | float, /) -> Array: + def __rmod__(self, other: Array | float, /) -> Array: """ Performs the operation __rmod__. """ @@ -1136,7 +1136,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array: res = self._array.__rmod__(other._array) return self.__class__._new(res, device=self.device) - def __imul__(self, other: Array | int | float | complex, /) -> Array: + def __imul__(self, other: Array | complex, /) -> Array: """ Performs the operation __imul__. """ @@ -1146,7 +1146,7 @@ def __imul__(self, other: Array | int | float | complex, /) -> Array: self._array.__imul__(other._array) return self - def __rmul__(self, other: Array | int | float | complex, /) -> Array: + def __rmul__(self, other: Array | complex, /) -> Array: """ Performs the operation __rmul__. """ @@ -1158,7 +1158,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array: res = self._array.__rmul__(other._array) return self.__class__._new(res, device=self.device) - def __ior__(self, other: Array | bool | int, /) -> Array: + def __ior__(self, other: Array | int, /) -> Array: """ Performs the operation __ior__. """ @@ -1168,7 +1168,7 @@ def __ior__(self, other: Array | bool | int, /) -> Array: self._array.__ior__(other._array) return self - def __ror__(self, other: Array | bool | int, /) -> Array: + def __ror__(self, other: Array | int, /) -> Array: """ Performs the operation __ror__. """ @@ -1180,7 +1180,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array: res = self._array.__ror__(other._array) return self.__class__._new(res, device=self.device) - def __ipow__(self, other: Array | int | float | complex, /) -> Array: + def __ipow__(self, other: Array | complex, /) -> Array: """ Performs the operation __ipow__. """ @@ -1190,7 +1190,7 @@ def __ipow__(self, other: Array | int | float | complex, /) -> Array: self._array.__ipow__(other._array) return self - def __rpow__(self, other: Array | int | float | complex, /) -> Array: + def __rpow__(self, other: Array | complex, /) -> Array: """ Performs the operation __rpow__. """ @@ -1225,7 +1225,7 @@ def __rrshift__(self, other: Array | int, /) -> Array: res = self._array.__rrshift__(other._array) return self.__class__._new(res, device=self.device) - def __isub__(self, other: Array | int | float | complex, /) -> Array: + def __isub__(self, other: Array | complex, /) -> Array: """ Performs the operation __isub__. """ @@ -1235,7 +1235,7 @@ def __isub__(self, other: Array | int | float | complex, /) -> Array: self._array.__isub__(other._array) return self - def __rsub__(self, other: Array | int | float | complex, /) -> Array: + def __rsub__(self, other: Array | complex, /) -> Array: """ Performs the operation __rsub__. """ @@ -1247,7 +1247,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array: res = self._array.__rsub__(other._array) return self.__class__._new(res, device=self.device) - def __itruediv__(self, other: Array | int | float | complex, /) -> Array: + def __itruediv__(self, other: Array | complex, /) -> Array: """ Performs the operation __itruediv__. """ @@ -1257,7 +1257,7 @@ def __itruediv__(self, other: Array | int | float | complex, /) -> Array: self._array.__itruediv__(other._array) return self - def __rtruediv__(self, other: Array | int | float | complex, /) -> Array: + def __rtruediv__(self, other: Array | complex, /) -> Array: """ Performs the operation __rtruediv__. """ @@ -1269,7 +1269,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array: res = self._array.__rtruediv__(other._array) return self.__class__._new(res, device=self.device) - def __ixor__(self, other: Array | bool | int, /) -> Array: + def __ixor__(self, other: Array | int, /) -> Array: """ Performs the operation __ixor__. """ @@ -1279,7 +1279,7 @@ def __ixor__(self, other: Array | bool | int, /) -> Array: self._array.__ixor__(other._array) return self - def __rxor__(self, other: Array | bool | int, /) -> Array: + def __rxor__(self, other: Array | int, /) -> Array: """ Performs the operation __rxor__. """ diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index db3897c..69d37aa 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -68,13 +68,7 @@ def _check_device(device: Device | None) -> None: def asarray( - obj: Array - | bool - | int - | float - | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol, + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, dtype: DType | None = None, @@ -135,10 +129,10 @@ def asarray( def arange( - start: int | float, + start: float, /, - stop: int | float | None = None, - step: int | float = 1, + stop: float | None = None, + step: float = 1, *, dtype: DType | None = None, device: Device | None = None, @@ -248,7 +242,7 @@ def from_dlpack( def full( shape: int | tuple[int, ...], - fill_value: bool | int | float | complex, + fill_value: complex, *, dtype: DType | None = None, device: Device | None = None, @@ -276,7 +270,7 @@ def full( def full_like( x: Array, /, - fill_value: bool | int | float | complex, + fill_value: complex, *, dtype: DType | None = None, device: Device | None = None, @@ -302,8 +296,8 @@ def full_like( def linspace( - start: int | float | complex, - stop: int | float | complex, + start: complex, + stop: complex, /, num: int, *, diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index e318724..82d438f 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -210,9 +210,7 @@ def isdtype(dtype: DType, kind: DType | str | tuple[DType | str, ...]) -> bool: raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}") -def result_type( - *arrays_and_dtypes: DType | Array | bool | int | float | complex, -) -> DType: +def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: """ Array API compatible wrapper for :py:func:`np.result_type `. diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index b05e0fd..8cd42cf 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -51,15 +51,15 @@ def inner(x1, x2, /) -> Array: # static type annotation for ArrayOrPythonScalar arguments given a category # NB: keep the keys in sync with the _dtype_categories dict _annotations = { - "all": "bool | int | float | complex | Array", - "real numeric": "int | float | Array", - "numeric": "int | float | complex | Array", + "all": "complex | Array", + "real numeric": "float | Array", + "numeric": "complex | Array", "integer": "int | Array", - "integer or boolean": "bool | int | Array", + "integer or boolean": "int | Array", "boolean": "bool | Array", "real floating-point": "float | Array", "complex floating-point": "complex | Array", - "floating-point": "float | complex | Array", + "floating-point": "complex | Array", } @@ -268,8 +268,8 @@ def ceil(x: Array, /) -> Array: def clip( x: Array, /, - min: Array | int | float | None = None, - max: Array | int | float | None = None, + min: Array | float | None = None, + max: Array | float | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.clip `. diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py index e8c6767..db58667 100644 --- a/array_api_strict/_helpers.py +++ b/array_api_strict/_helpers.py @@ -8,8 +8,8 @@ def _maybe_normalize_py_scalars( - x1: Array | bool | int | float | complex, - x2: Array | bool | int | float | complex, + x1: Array | complex, + x2: Array | complex, dtype_category: str, func_name: str, ) -> tuple[Array, Array]: diff --git a/array_api_strict/_linalg.py b/array_api_strict/_linalg.py index 72d7f0a..84a31f6 100644 --- a/array_api_strict/_linalg.py +++ b/array_api_strict/_linalg.py @@ -415,7 +415,7 @@ def vector_norm( *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - ord: int | float = 2, + ord: float = 2, ) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm `. diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index c42ccc7..3fbb0cf 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -92,12 +92,7 @@ def searchsorted( ) -def where( - condition: Array, - x1: Array | bool | int | float | complex, - x2: Array | bool | int | float | complex, - /, -) -> Array: +def where(condition: Array, x1: Array | complex, x2: Array | complex, /) -> Array: """ Array API compatible wrapper for :py:func:`np.where `. diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 4160f7a..35876dd 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -149,7 +149,7 @@ def std( /, *, axis: int | tuple[int, ...] | None = None, - correction: int | float = 0.0, + correction: float = 0.0, keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here @@ -181,7 +181,7 @@ def var( /, *, axis: int | tuple[int, ...] | None = None, - correction: int | float = 0.0, + correction: float = 0.0, keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here From e25b9281fd9b94403ffda337d3d0240bb155f530 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Sat, 10 May 2025 23:48:01 +0200 Subject: [PATCH 249/252] ENH: unary functions overhaul; better input validation (#148) * ENH: unary functions overhaul * trivial fix * lint * reduce diff size * alphabetical order * type annotations * Code review --- array_api_strict/_elementwise_functions.py | 567 ++++-------------- array_api_strict/_helpers.py | 18 +- array_api_strict/_searching_functions.py | 3 + array_api_strict/tests/test_array_object.py | 10 +- .../tests/test_elementwise_functions.py | 139 +++-- .../tests/test_searching_functions.py | 58 ++ 6 files changed, 285 insertions(+), 510 deletions(-) diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 8cd42cf..d86bc27 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -1,15 +1,16 @@ +from collections.abc import Callable +from functools import wraps +from types import NoneType + import numpy as np from ._array_object import Array from ._creation_functions import asarray from ._data_type_functions import broadcast_to, iinfo from ._dtypes import ( - _boolean_dtypes, _complex_floating_dtypes, _dtype_categories, - _floating_dtypes, _integer_dtypes, - _integer_or_boolean_dtypes, _numeric_dtypes, _real_floating_dtypes, _real_numeric_dtypes, @@ -35,7 +36,7 @@ def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func): return Array._new(np_func(x1._array, x2._array), device=x1.device) -_binary_docstring_template = """ +_docstring_template = """ Array API compatible wrapper for :py:func:`np.%s `. See its docstring for more information. @@ -117,7 +118,7 @@ def inner(x1, x2, /) -> Array: func = _create_binary_func(func_name, dtype_category, np_func) func.__name__ = func_name - func.__doc__ = _binary_docstring_template % (numpy_name, numpy_name) + func.__doc__ = _docstring_template % (numpy_name, numpy_name) func.__annotations__['x1'] = _annotations[dtype_category] func.__annotations__['x2'] = _annotations[dtype_category] @@ -153,115 +154,86 @@ def bitwise_right_shift(x1: int | Array, x2: int | Array, /) -> Array: del func, _create_binary_func -def abs(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.abs `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in abs") - return Array._new(np.abs(x._array), device=x.device) - - -# Note: the function name is different here -def acos(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arccos `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in acos") - return Array._new(np.arccos(x._array), device=x.device) - - -# Note: the function name is different here -def acosh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arccosh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in acosh") - return Array._new(np.arccosh(x._array), device=x.device) - - -# Note: the function name is different here -def asin(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arcsin `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in asin") - return Array._new(np.arcsin(x._array), device=x.device) - - -# Note: the function name is different here -def asinh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arcsinh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in asinh") - return Array._new(np.arcsinh(x._array), device=x.device) - - -# Note: the function name is different here -def atan(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arctan `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in atan") - return Array._new(np.arctan(x._array), device=x.device) - - -# Note: the function name is different here -def atanh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.arctanh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in atanh") - return Array._new(np.arctanh(x._array), device=x.device) - - -# Note: the function name is different here -def bitwise_invert(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.invert `. - - See its docstring for more information. - """ - if x.dtype not in _integer_or_boolean_dtypes: - raise TypeError("Only integer or boolean dtypes are allowed in bitwise_invert") - return Array._new(np.invert(x._array), device=x.device) - +def _create_unary_func( + func_name: str, + dtype_category: str, + np_func_name: str | None = None, +) -> Callable[[Array], Array]: + allowed_dtypes = _dtype_categories[dtype_category] + np_func_name = np_func_name or func_name + np_func = getattr(np, np_func_name) + + def func(x: Array, /) -> Array: + if not isinstance(x, Array): + raise TypeError(f"Only Array objects are allowed; got {type(x)}") + if x.dtype not in allowed_dtypes: + raise TypeError( + f"Only {dtype_category} dtypes are allowed in {func_name}; " + f"got {x.dtype}." + ) + return Array._new(np_func(x._array), device=x.device) -def ceil(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.ceil `. - - See its docstring for more information. - """ - if x.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in ceil") - if x.dtype in _integer_dtypes: - # Note: The return dtype of ceil is the same as the input - return x - return Array._new(np.ceil(x._array), device=x.device) + func.__name__ = func_name + func.__doc__ = _docstring_template % (np_func_name, np_func_name) + return func + + +def _identity_if_integer(func: Callable[[Array], Array]) -> Callable[[Array], Array]: + """Hack around NumPy 1.x behaviour for ceil, floor, and trunc + vs. integer inputs + """ + + @wraps(func) + def wrapper(x: Array, /) -> Array: + if isinstance(x, Array) and x.dtype in _integer_dtypes: + return x + return func(x) + + return wrapper + + +abs = _create_unary_func("abs", "numeric") +acos = _create_unary_func("acos", "floating-point", "arccos") +acosh = _create_unary_func("acosh", "floating-point", "arccosh") +asin = _create_unary_func("asin", "floating-point", "arcsin") +asinh = _create_unary_func("asinh", "floating-point", "arcsinh") +atan = _create_unary_func("atan", "floating-point", "arctan") +atanh = _create_unary_func("atanh", "floating-point", "arctanh") +bitwise_invert = _create_unary_func("bitwise_invert", "integer or boolean", "invert") +ceil = _identity_if_integer(_create_unary_func("ceil", "real numeric")) +conj = _create_unary_func("conj", "numeric") +cos = _create_unary_func("cos", "floating-point", "cos") +cosh = _create_unary_func("cosh", "floating-point", "cosh") +exp = _create_unary_func("exp", "floating-point") +expm1 = _create_unary_func("expm1", "floating-point") +floor = _identity_if_integer(_create_unary_func("floor", "real numeric")) +imag = _create_unary_func("imag", "complex floating-point") +isfinite = _create_unary_func("isfinite", "numeric") +isinf = _create_unary_func("isinf", "numeric") +isnan = _create_unary_func("isnan", "numeric") +log = _create_unary_func("log", "floating-point") +log10 = _create_unary_func("log10", "floating-point") +log1p = _create_unary_func("log1p", "floating-point") +log2 = _create_unary_func("log2", "floating-point") +logical_not = _create_unary_func("logical_not", "boolean") +negative = _create_unary_func("negative", "numeric") +positive = _create_unary_func("positive", "numeric") +real = _create_unary_func("real", "numeric") +reciprocal = requires_api_version("2024.12")( + _create_unary_func("reciprocal", "floating-point") +) +round = _create_unary_func("round", "numeric") +signbit = requires_api_version("2023.12")( + _create_unary_func("signbit", "real floating-point") +) +sin = _create_unary_func("sin", "floating-point") +sinh = _create_unary_func("sinh", "floating-point") +sqrt = _create_unary_func("sqrt", "floating-point") +square = _create_unary_func("square", "numeric") +tan = _create_unary_func("tan", "floating-point") +tanh = _create_unary_func("tanh", "floating-point") +trunc = _identity_if_integer(_create_unary_func("trunc", "real numeric")) -# WARNING: This function is not yet tested by the array-api-tests test suite. # Note: min and max argument names are different and not optional in numpy. @requires_api_version('2023.12') @@ -276,42 +248,40 @@ def clip( See its docstring for more information. """ - if isinstance(min, Array) and x.device != min.device: - raise ValueError(f"Arrays from two different devices ({x.device} and {min.device}) can not be combined.") - if isinstance(max, Array) and x.device != max.device: - raise ValueError(f"Arrays from two different devices ({x.device} and {max.device}) can not be combined.") + if not isinstance(x, Array): + raise TypeError(f"Only Array objects are allowed; got {type(x)}") if (x.dtype not in _real_numeric_dtypes or isinstance(min, Array) and min.dtype not in _real_numeric_dtypes or isinstance(max, Array) and max.dtype not in _real_numeric_dtypes): raise TypeError("Only real numeric dtypes are allowed in clip") - if not isinstance(min, (int, float, Array, type(None))): - raise TypeError("min must be an None, int, float, or an array") - if not isinstance(max, (int, float, Array, type(None))): - raise TypeError("max must be an None, int, float, or an array") - - # Mixed dtype kinds is implementation defined - if (x.dtype in _integer_dtypes - and (isinstance(min, float) or - isinstance(min, Array) and min.dtype in _real_floating_dtypes)): - raise TypeError("min must be integral when x is integral") - if (x.dtype in _integer_dtypes - and (isinstance(max, float) or - isinstance(max, Array) and max.dtype in _real_floating_dtypes)): - raise TypeError("max must be integral when x is integral") - if (x.dtype in _real_floating_dtypes - and (isinstance(min, int) or - isinstance(min, Array) and min.dtype in _integer_dtypes)): - raise TypeError("min must be floating-point when x is floating-point") - if (x.dtype in _real_floating_dtypes - and (isinstance(max, int) or - isinstance(max, Array) and max.dtype in _integer_dtypes)): - raise TypeError("max must be floating-point when x is floating-point") if min is max is None: - # Note: NumPy disallows min = max = None return x + for argname, arg in ("min", min), ("max", max): + if isinstance(arg, Array): + if x.device != arg.device: + raise ValueError( + f"Arrays from two different devices ({x.device} and {arg.device}) " + "can not be combined." + ) + # Disallow subclasses of Python scalars, e.g. np.float64 + elif type(arg) not in (int, float, NoneType): + raise TypeError( + f"{argname} must be None, int, float, or Array; got {type(arg)}" + ) + + # Mixed dtype kinds is implementation defined + if (x.dtype in _integer_dtypes + and (isinstance(arg, float) or + isinstance(arg, Array) and arg.dtype in _real_floating_dtypes)): + raise TypeError(f"{argname} must be integral when x is integral") + if (x.dtype in _real_floating_dtypes + and (isinstance(arg, int) or + isinstance(arg, Array) and arg.dtype in _integer_dtypes)): + raise TypeError(f"{arg} must be floating-point when x is floating-point") + # Normalize to make the below logic simpler if min is not None: min = asarray(min)._array @@ -370,330 +340,17 @@ def _isscalar(a): return Array._new(out, device=device) -def conj(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.conj `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in conj") - return Array._new(np.conj(x._array), device=x.device) - - -def cos(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.cos `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in cos") - return Array._new(np.cos(x._array), device=x.device) - - -def cosh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.cosh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in cosh") - return Array._new(np.cosh(x._array), device=x.device) - - -def exp(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.exp `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in exp") - return Array._new(np.exp(x._array), device=x.device) - - -def expm1(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.expm1 `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in expm1") - return Array._new(np.expm1(x._array), device=x.device) - - -def floor(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.floor `. - - See its docstring for more information. - """ - if x.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in floor") - if x.dtype in _integer_dtypes: - # Note: The return dtype of floor is the same as the input - return x - return Array._new(np.floor(x._array), device=x.device) - - -def imag(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.imag `. - - See its docstring for more information. - """ - if x.dtype not in _complex_floating_dtypes: - raise TypeError("Only complex floating-point dtypes are allowed in imag") - return Array._new(np.imag(x._array), device=x.device) - - -def isfinite(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.isfinite `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in isfinite") - return Array._new(np.isfinite(x._array), device=x.device) - - -def isinf(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.isinf `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in isinf") - return Array._new(np.isinf(x._array), device=x.device) - - -def isnan(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.isnan `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in isnan") - return Array._new(np.isnan(x._array), device=x.device) - - -def log(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.log `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in log") - return Array._new(np.log(x._array), device=x.device) - - -def log1p(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.log1p `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in log1p") - return Array._new(np.log1p(x._array), device=x.device) - - -def log2(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.log2 `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in log2") - return Array._new(np.log2(x._array), device=x.device) - - -def log10(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.log10 `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in log10") - return Array._new(np.log10(x._array), device=x.device) - - -def logical_not(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.logical_not `. - - See its docstring for more information. - """ - if x.dtype not in _boolean_dtypes: - raise TypeError("Only boolean dtypes are allowed in logical_not") - return Array._new(np.logical_not(x._array), device=x.device) - - -def negative(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.negative `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in negative") - return Array._new(np.negative(x._array), device=x.device) - - -def positive(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.positive `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in positive") - return Array._new(np.positive(x._array), device=x.device) - - -def real(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.real `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in real") - return Array._new(np.real(x._array), device=x.device) - - -@requires_api_version('2024.12') -def reciprocal(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.reciprocal `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in reciprocal") - return Array._new(np.reciprocal(x._array), device=x.device) - - -def round(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.round `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in round") - return Array._new(np.round(x._array), device=x.device) - - def sign(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.sign `. See its docstring for more information. """ + if not isinstance(x, Array): + raise TypeError(f"Only Array objects are allowed; got {type(x)}") if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sign") + # Special treatment to work around non-compliant NumPy 1.x behaviour if x.dtype in _complex_floating_dtypes: return x/abs(x) return Array._new(np.sign(x._array), device=x.device) - - -@requires_api_version('2023.12') -def signbit(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.signbit `. - - See its docstring for more information. - """ - if x.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in signbit") - return Array._new(np.signbit(x._array), device=x.device) - - -def sin(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.sin `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in sin") - return Array._new(np.sin(x._array), device=x.device) - - -def sinh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.sinh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in sinh") - return Array._new(np.sinh(x._array), device=x.device) - - -def square(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.square `. - - See its docstring for more information. - """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in square") - return Array._new(np.square(x._array), device=x.device) - - -def sqrt(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.sqrt `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in sqrt") - return Array._new(np.sqrt(x._array), device=x.device) - - -def tan(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.tan `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in tan") - return Array._new(np.tan(x._array), device=x.device) - - -def tanh(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.tanh `. - - See its docstring for more information. - """ - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in tanh") - return Array._new(np.tanh(x._array), device=x.device) - - -def trunc(x: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.trunc `. - - See its docstring for more information. - """ - if x.dtype not in _real_numeric_dtypes: - raise TypeError("Only real numeric dtypes are allowed in trunc") - if x.dtype in _integer_dtypes: - # Note: The return dtype of trunc is the same as the input - return x - return Array._new(np.trunc(x._array), device=x.device) diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py index db58667..35fa7c8 100644 --- a/array_api_strict/_helpers.py +++ b/array_api_strict/_helpers.py @@ -4,7 +4,7 @@ from ._dtypes import _dtype_categories from ._flags import get_array_api_strict_flags -_py_scalars = (bool, int, float, complex) +_PY_SCALARS = (bool, int, float, complex) def _maybe_normalize_py_scalars( @@ -20,20 +20,26 @@ def _maybe_normalize_py_scalars( _allowed_dtypes = _dtype_categories[dtype_category] - if isinstance(x1, _py_scalars): - if isinstance(x2, _py_scalars): + # Disallow subclasses, e.g. np.float64 and np.complex128 + if type(x1) in _PY_SCALARS: + if type(x2) in _PY_SCALARS: raise TypeError(f"Two scalars not allowed, got {type(x1) = } and {type(x2) =}") - # x2 must be an array + if not isinstance(x2, Array): + raise TypeError(f"Argument is neither an Array nor a Python scalar: {type(x2)=} ") if x2.dtype not in _allowed_dtypes: raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x2.dtype}.") x1 = x2._promote_scalar(x1) - elif isinstance(x2, _py_scalars): - # x1 must be an array + elif type(x2) in _PY_SCALARS: + if not isinstance(x1, Array): + raise TypeError(f"Argument is neither an Array nor a Python scalar: {type(x2)=} ") if x1.dtype not in _allowed_dtypes: raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x1.dtype}.") x2 = x1._promote_scalar(x2) else: + if not isinstance(x1, Array) or not isinstance(x2, Array): + raise TypeError(f"Argument(s) are neither Array nor Python scalars: {type(x1)=} and {type(x2)=}") + if x1.dtype not in _allowed_dtypes or x2.dtype not in _allowed_dtypes: raise TypeError(f"Only {dtype_category} dtypes are allowed in {func_name}(...). " f"Got {x1.dtype} and {x2.dtype}.") diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 3fbb0cf..5334e8f 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -98,6 +98,9 @@ def where(condition: Array, x1: Array | complex, x2: Array | complex, /) -> Arra See its docstring for more information. """ + if not isinstance(condition, Array): + raise TypeError(f"`condition` must be an Array; got {type(condition)}") + x1, x2 = _maybe_normalize_py_scalars(x1, x2, "all", "where") # Call result type here just to raise on disallowed type combinations diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index e950be5..15d88a9 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -412,10 +412,12 @@ def _matmul_array_vals(): @pytest.mark.parametrize("op,dtypes", binary_op_dtypes.items()) -def test_binary_operators_vs_numpy_generics(op, dtypes): - """Test that np.bool_, np.int64, np.float32, np.float64, np.complex64, np.complex128 - are disallowed in binary operators. - np.float64 and np.complex128 are subclasses of float and complex, so they need +def test_binary_operators_numpy_scalars(op, dtypes): + """ + Test that NumPy scalars (np.generic) are explicitly disallowed. + + This must notably include np.float64 and np.complex128, which are + subclasses of float and complex respectively, so they need special treatment in order to be rejected. """ match = "Expected Array or Python scalar" diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 99596b4..0f740d3 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -1,24 +1,26 @@ from inspect import signature, getmodule -from pytest import raises as assert_raises +import numpy as np +import pytest from numpy.testing import suppress_warnings from .. import asarray, _elementwise_functions +from .._array_object import ALL_DEVICES, CPU_DEVICE, Device from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift from .._dtypes import ( _dtype_categories, _boolean_dtypes, _floating_dtypes, _integer_dtypes, + bool as xp_bool, + float64, int8, int16, int32, int64, uint64, ) -from .._flags import set_array_api_strict_flags - from .test_array_object import _check_op_array_scalar, BIG_INT import array_api_strict @@ -104,6 +106,13 @@ def nargs(func): } +elementwise_binary_function_names = [ + func_name + for func_name in elementwise_function_input_types + if nargs(getattr(_elementwise_functions, func_name)) == 2 +] + + def test_nargs(): # Explicitly check number of arguments for a few functions assert nargs(array_api_strict.logaddexp) == 2 @@ -126,33 +135,81 @@ def test_missing_functions(): assert set(mod_funcs) == set(elementwise_function_input_types) -def test_function_device_persists(): - # Test that the device of the input and output array are the same +@pytest.mark.parametrize("device", ALL_DEVICES) +@pytest.mark.parametrize("func_name,types", elementwise_function_input_types.items()) +def test_elementwise_function_device_persists(func_name, types, device): + """Test that the device of the input and output array are the same""" def _array_vals(dtypes): - for d in dtypes: - yield asarray(1., dtype=d) - - # Use the latest version of the standard so all functions are included - set_array_api_strict_flags(api_version="2024.12") - - for func_name, types in elementwise_function_input_types.items(): - dtypes = _dtype_categories[types] - func = getattr(_elementwise_functions, func_name) + for dtype in dtypes: + yield asarray(1., dtype=dtype, device=device) + + dtypes = _dtype_categories[types] + func = getattr(_elementwise_functions, func_name) + + for x in _array_vals(dtypes): + if nargs(func) == 2: + # This way we don't have to deal with incompatible + # types of the two arguments. + r = func(x, x) + assert r.device == x.device + + else: + # `atanh` needs a slightly different input value from + # everyone else + if func_name == "atanh": + x -= 0.1 + r = func(x) + assert r.device == x.device + + +@pytest.mark.parametrize("func_name", elementwise_binary_function_names) +def test_elementwise_function_device_mismatch(func_name): + func = getattr(_elementwise_functions, func_name) + dtypes = elementwise_function_input_types[func_name] + if dtypes in ("floating-point", "real floating-point"): + dtype = float64 + elif dtypes == "boolean": + dtype = xp_bool + else: + dtype = int64 + + a = asarray(1, dtype=dtype, device=CPU_DEVICE) + b = asarray(1, dtype=dtype, device=Device("device1")) + _ = func(a, a) + with pytest.raises(ValueError, match="different devices"): + func(a, b) + + +@pytest.mark.parametrize("func_name", elementwise_function_input_types) +def test_elementwise_function_numpy_scalars(func_name): + """ + Test that NumPy scalars (np.generic) are explicitly disallowed. + + This must notably include np.float64 and np.complex128, which are + subclasses of float and complex respectively, so they need + special treatment in order to be rejected. + """ + func = getattr(_elementwise_functions, func_name) + dtypes = elementwise_function_input_types[func_name] + xp_dtypes = _dtype_categories[dtypes] + np_dtypes = [dtype._np_dtype for dtype in xp_dtypes] + + value = 0.5 if func_name == "atanh" else 1 + for xp_dtype in xp_dtypes: + for np_dtype in np_dtypes: + a = asarray(value, dtype=xp_dtype, device=CPU_DEVICE) + b = np.asarray(value, dtype=np_dtype)[()] - for x in _array_vals(dtypes): if nargs(func) == 2: - # This way we don't have to deal with incompatible - # types of the two arguments. - r = func(x, x) - assert r.device == x.device - + _ = func(a, a) + with pytest.raises(TypeError, match="neither Array nor Python scalars"): + func(a, b) + with pytest.raises(TypeError, match="neither Array nor Python scalars"): + func(b, a) else: - # `atanh` needs a slightly different input value from - # everyone else - if func_name == "atanh": - x -= 0.1 - r = func(x) - assert r.device == x.device + _ = func(a) + with pytest.raises(TypeError, match="allowed"): + func(b) def test_function_types(): @@ -168,9 +225,6 @@ def _array_vals(): for d in _floating_dtypes: yield asarray(1.0, dtype=d) - # Use the latest version of the standard so all functions are included - set_array_api_strict_flags(api_version="2024.12") - for x in _array_vals(): for func_name, types in elementwise_function_input_types.items(): dtypes = _dtype_categories[types] @@ -187,23 +241,23 @@ def _array_vals(): or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes ): - assert_raises(TypeError, func, x, y) + with pytest.raises(TypeError): + func(x, y) if x.dtype not in dtypes or y.dtype not in dtypes: - assert_raises(TypeError, func, x, y) + with pytest.raises(TypeError): + func(x, y) else: if x.dtype not in dtypes: - assert_raises(TypeError, func, x) + with pytest.raises(TypeError): + func(x) def test_bitwise_shift_error(): # bitwise shift functions should raise when the second argument is negative - assert_raises( - ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1])) - ) - assert_raises( - ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1])) - ) - + with pytest.raises(ValueError): + bitwise_left_shift(asarray([1, 1]), asarray([1, -1])) + with pytest.raises(ValueError): + bitwise_right_shift(asarray([1, 1]), asarray([1, -1])) def test_scalars(): @@ -212,9 +266,6 @@ def test_scalars(): # Also check that binary functions accept (array, scalar) and (scalar, array) # arguments, and reject (scalar, scalar) arguments. - # Use the latest version of the standard so that scalars are actually allowed - set_array_api_strict_flags(api_version="2024.12") - def _array_vals(): for d in _integer_dtypes: yield asarray(1, dtype=d) @@ -256,7 +307,5 @@ def _array_vals(): assert func(s, a) == func(conv_scalar, a) assert func(a, s) == func(a, conv_scalar) - with assert_raises(TypeError): + with pytest.raises(TypeError): func(s, s) - - diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index 2a3a79e..abe1949 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -3,6 +3,8 @@ import array_api_strict as xp from array_api_strict import ArrayAPIStrictFlags +from .._array_object import ALL_DEVICES, CPU_DEVICE, Device +from .._dtypes import _all_dtypes def test_where_with_scalars(): @@ -23,6 +25,10 @@ def test_where_with_scalars(): with pytest.raises(TypeError, match="Two scalars"): xp.where(x == 1, 42, 44) + # The spec does not allow for condition to be scalar + with pytest.raises(TypeError, match="Array"): + xp.where(True, x, x) + def test_where_mixed_dtypes(): # https://github.com/data-apis/array-api-strict/issues/131 @@ -42,3 +48,55 @@ def test_where_f32(): res = xp.where(xp.asarray([True, False]), 1., xp.asarray([2, 2], dtype=xp.float32)) assert res.dtype == xp.float32 + +@pytest.mark.parametrize("device", ALL_DEVICES) +def test_where_device_persists(device): + """Test that the device of the input and output array are the same""" + + cond = xp.asarray([True, False], device=device) + x1 = xp.asarray([1, 2], device=device) + x2 = xp.asarray([3, 4], device=device) + res = xp.where(cond, x1, x2) + assert res.device == device + res = xp.where(cond, 1, x2) + assert res.device == device + res = xp.where(cond, x1, 2) + assert res.device == device + + +@pytest.mark.parametrize( + "cond_device,x1_device,x2_device", + [ + (CPU_DEVICE, CPU_DEVICE, Device("device1")), + (CPU_DEVICE, Device("device1"), CPU_DEVICE), + (Device("device1"), CPU_DEVICE, CPU_DEVICE), + ] +) +def test_where_device_mismatch(cond_device, x1_device, x2_device): + cond = xp.asarray([True, False], device=cond_device) + x1 = xp.asarray([1, 2], device=x1_device) + x2 = xp.asarray([3, 4], device=x2_device) + with pytest.raises(ValueError, match="device"): + xp.where(cond, x1, x2) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_where_numpy_scalars(dtype): + """ + Test that NumPy scalars (np.generic) are explicitly disallowed. + + This must notably include np.float64 and np.complex128, which are + subclasses of float and complex respectively, so they need + special treatment in order to be rejected. + """ + cond = xp.asarray(True) + x1 = xp.asarray(1, dtype=dtype) + x2 = xp.asarray(1, dtype=dtype) + _ = xp.where(cond, x1, x2) + + with pytest.raises(TypeError, match="neither Array nor Python scalars"): + xp.where(cond, x1, x2._array[()]) + with pytest.raises(TypeError, match="neither Array nor Python scalars"): + xp.where(cond, x1._array[()], x2) + with pytest.raises(TypeError, match="must be an Array"): + xp.where(cond._array[()], x1, x2) From 4d23298c8a2c89d04fef5d33ef389e6f5ea70b7b Mon Sep 17 00:00:00 2001 From: Lumir Balhar Date: Fri, 16 May 2025 12:22:55 +0200 Subject: [PATCH 250/252] Fix test_iter with Python 3.14 beta 1 In Python 3.14 beta 1, generator expression returned by the __iter__ method needs to be executed to throw the expected TypeError. Fixes: https://github.com/data-apis/array-api-strict/issues/151 --- array_api_strict/tests/test_array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 15d88a9..ae6627a 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -637,7 +637,7 @@ def test_array_namespace(): pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2026.12")) def test_iter(): - pytest.raises(TypeError, lambda: iter(asarray(3))) + pytest.raises(TypeError, lambda: next(iter(asarray(3)))) assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)] assert all_(isinstance(a, Array) for a in iter(ones(3))) assert all_(a.shape == () for a in iter(ones(3))) From 79ff688521e7bc1056c8f88b091de84329fc3b15 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 16 May 2025 12:34:08 +0200 Subject: [PATCH 251/252] MAINT: make reshape require shape to be a tuple --- array_api_strict/_manipulation_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_strict/_manipulation_functions.py b/array_api_strict/_manipulation_functions.py index fe4a608..5cd85d1 100644 --- a/array_api_strict/_manipulation_functions.py +++ b/array_api_strict/_manipulation_functions.py @@ -113,6 +113,8 @@ def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> See its docstring for more information. """ + if not isinstance(shape, tuple): + raise TypeError(f"`shape` must be a tuple of ints; got {shape=} instead.") data = x._array if copy: From 7065adba6ed2a7c7f3126636a88f506cbdcd58d9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Jun 2025 21:43:56 +0000 Subject: [PATCH 252/252] Bump dawidd6/action-download-artifact from 9 to 10 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 9 to 10 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v9...v10) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-version: '10' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index fc61258..4e3efb3 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download Artifact - uses: dawidd6/action-download-artifact@v9 + uses: dawidd6/action-download-artifact@v10 with: workflow: docs-build.yml name: docs-build