-
-
Notifications
You must be signed in to change notification settings - Fork 26.5k
DOC Add info on 'array-like' array API inputs when array_api_dispatch=False
#32676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Another thought that I am 50/50 on - currently when an array that is not able to be converted via What do we think about adding something along the lines of - "if you wish to enable array API support and get {namespace} arrays as output, set the configuration At the moment this change would be a bit complicated because we perform the |
|
I think it would be beneficial to mention that array API dispatch should be activated. I think we try to handle non NumPy arrays with |
|
Yeah I agree, this seems like a niche/odd thing that we do, but we wouldn't really recommend anyone to use it. I've amended the wording. I think it is important to document, so people know this happens, because it can be quite unexpected to get a different namespaced array than what you've input e.g., you've forgotten to set |
array_api_dispatch=Falsearray_api_dispatch=False
|
I'm not sure if the array API section is the right place for this information. I think we should have this in a more general section of the scikit-learn docs, possibly we already state it there :-/ It has happened that as part of the array API we've changed what happens when you pass torch arrays to scikit-learn. In which case we get reports about it breaking. This makes me think that it is something people rely and is possible on purpose. But I don't fully know the "rule of thumb" for what works and how. I think it is something like "if it looks enough like a Numpy array" and/or if it supports |
My naive guess would be arrays on CPU device can be converted. I have tried torch 'cpu' and jax (
Interesting, can you point me to any references?
Tricky to look up. In the SVM user guide we state:
In "Developing scikit-learn estimators" -> "Input validation" we state:
I looked for "array " and "asarray" in our I am happy to add a note about this behaviour in the metrics section of the user guide and link it to here? |
|
#29107 is an example of a bug like this. |
lucyleeow
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've realised that we (probably) do np.array or np.asarray on all inputs that we specify as array-like. Thus I amended the glossary entry for array-like to include pytorch tensors on cpu and JAX arrays as examples. Also amended working in the array API user guide.
I couldn't find another suitable place in the user guide to add this info, so I think I will leave for now and hope the glossary is adequate for now.
doc/modules/array_api.rst
Outdated
| to avoid having to reset it to `False` it at the end of every code snippet, so as to | ||
| not affect the rest of the documentation. | ||
|
|
||
| Scikit-learn accepts :term:`array-like` inputs for all :mod:`~sklearn.metrics` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
all :mod:
~sklearn.metrics
Is this true..?
doc/modules/array_api.rst
Outdated
| into NumPy arrays using :func:`numpy.asarray` (or :func:`numpy.array`). | ||
| While this will successfully convert some array API inputs (e.g., JAX array), | ||
| we generally recommend setting `array_api_dispatch=True` when using array API inputs. | ||
| Note when `array_api_dispatch=False`, array outputs will be NumPy arrays. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This last line is probably redundant, I am happy to remove.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's important to somehow highlight that when array_api_dispatch=False, NumPy conversion can fail, for instance, calling fit on a torch tensor allocated on a GPU will raise an error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a sentence, not 100% happy with it so let me know what you think.
I've also removed the last line, which was redundant, I think thats what @betatim 's thumbs up meant 😄
array_api_dispatch=Falsearray_api_dispatch=False
array_api_dispatch=Falsearray_api_dispatch=False
betatim
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More than super good enough!
|
I'll merge it. Olivier can comment or make a follow up PR if there is something we missed |
|
I don't think this is worth doing something about it but
I think the recommendation is quite mild so that's fine but passing a PyTorch cpu tensor has been working for ever and I guess people kind of expect the implicit conversion to numpy array 1. Something I did not think about until now, is that when array API becomes the default the output will change to a PyTorch tensor and there may be one workflow somewhere that breaks but it's probably quite unlikely to happen? Footnotes
|
|
I think this is something to consider as part of making array API the default. In some way the pytorch on CPU -> Numpy workflow is part of the scikit-learn API, in some way you could also argue that it is undocumented behaviour :-/ |
It isn't explicitly stated anywhere, just implied by the glossary entry: "array-like is any type object for which :func: Maybe I should amend this glossary entry to make it more explicit that inputs will be converted to numpy arrays, and array outputs will be numpy arrays...? |
|
I'd not do anything for now |
Reference Issues/PRs
Related #30454
Found out in #32600 (comment)
What does this implement/fix? Explain your changes.
Adds info on what happens when non-NumPy array input occurs with
array_api_dispatch=FalseI only realised that we do this and I think it would be nice if it was mentioned in the docs.
I thought I'd add a section on the
array_api_dispatchas it no longer seemed to fit under the 'Example usage' sectionAny other comments?
cc @OmarManzoor @lesteve