Update _unpack_to_numpy function to convert JAX and PyTorch arrays to NumPy #25887
Conversation
Thank you for opening your first PR into Matplotlib!
If you have not heard from us in a while, please feel free to ping @matplotlib/developers
or anyone who has commented on the PR. Most of our reviewers are volunteers and sometimes things fall through the cracks.
You can also join us on gitter for real-time discussion.
For details on testing, writing docs, and our review process, please see the developer guide.
We strive to be a welcoming and open project. Please follow our Code of Conduct.
Please remove the unrelated changes. |
Hmm, it seems like this is probably the clearest indication of the test failures:
As the snippet at lib/matplotlib/tests/test_units.py, lines 41 to 42 (at 10e8bf1), shows, the units will be dropped and the tests fail. I have no idea about the unit support though, so it is not really clear how to get around it... |
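(To make the failure mode concrete, a small illustration using pint as a stand-in unit library, assuming it is installed; this is not code from the test:)

```python
import numpy as np
import pint  # assumption: pint is installed

q = pint.UnitRegistry().Quantity([1, 2, 3], "m")
arr = np.asarray(q)   # typically emits a UnitStrippedWarning and drops the unit
print(arr)            # [1 2 3]  -- a plain ndarray, the "m" unit is gone
```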
FYI I don't think you're supposed to call __array__ directly. But I'm not sure how much it matters in practice, or whether another library like matplotlib is a "user" in this context. |
This is hitting the same reason why we do not blindly call Between this and the note about not calling |
Thank you for a diverse set of feedback, @oscargus, @mwaskom and @tacaswell. So, what do you think would be a better way to go ahead? Something like this:

```python
if str(type(x)) == "<class 'torch.Tensor'>":
    return x.numpy()
if str(type(x)) == "<class 'jaxlib.xla_extension.ArrayImpl'>":
    return np.asarray(x)
```

Or something completely different from these directions? |
Do not trust this fully, but I think that checking if there is a |
I am very 👎 on adding string checking of types. Unfortunately we are in an awkward bind where we are very permissive in what we take as input, do not want to depend on any imports, and due to the diversity of input we cannot treat them all the same. |
I have not checked, but maybe there is something in the Python array API standard - at least it would belong there. |
Would |
There's a third option not mentioned here: use |
My sense is that the direction data libraries would like to move is for exchange via the |
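(In case it is useful, a minimal sketch of what such exchange could look like if the mechanism meant here is the DLPack protocol; the helper name is made up and this is not part of the PR:)

```python
import numpy as np

def _via_dlpack(x):
    # Hypothetical helper: try zero-copy exchange via the DLPack protocol.
    # np.from_dlpack needs NumPy >= 1.22 and a CPU-resident producer array.
    try:
        return np.from_dlpack(x)
    except (AttributeError, TypeError, RuntimeError, BufferError):
        # Producer doesn't speak DLPack, or the data lives on an accelerator.
        return x
```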
We discussed the history of this a bit on the dev call today. I think the below is close to correct, but I could be misunderstanding:

We officially support numpy arrays as inputs to our data plotting functions. We also officially support mechanisms for objects to get passed that contain "unit" information (e.g. pint). Somewhat confusingly, this unit information is sometimes at the container level (e.g. pint), and sometimes at the element level, or the dtype of the elements (e.g. numpy arrays of datetime64, or lists of strings). We unofficially support xarray and pandas objects, and assume they have no units, by calling their to_numpy methods. After we have checked for units, we usually call np.asarray.

I'm not sure what the path out of the conundrum is - I somewhat feel the unit conversion interface should have been less magical, and more explicit, so users would have to specify a converter on an axis manually, rather than us guessing the converter. |
That isn't quite the right thing; the array API standard is meant to use "native functions", so this method is what you'd use if you want to retrieve the I agree with @jakevdp and @mwaskom that use of
If units libraries silently lose data when |
From an interface perspective, it's reasonable to rely on |
@patel-zeel in my comment #25882 (comment), I hadn't considered the unit problem. That indeed makes the problem much more complicated. To all: To summarize and comment on the above proposed solutions:
There is no easy solution here. Special situations sometimes require special measures: Given all the boundary conditions, I'd be +0.5 on type-checking by string, despite @tacaswell being strongly 👎 on this. Usually, I'd agree, but that's the only realistic way forward. 1. won't happen; 2. is introducing strong coupling, which IMHO is worse; 4. won't realistically happen, because we don't have the capacity for it.

So what would we buy into with type-checking by string? Drawbacks are (1) the str comparison is slower than a type check - but that should be negligible; and (2) it's brittle because the str representation could change without us noticing and then the functionality would be broken. To alleviate (2), we could use

The only other alternative would be to tell users to convert their JAX/Torch arrays explicitly (or live with the performance impact). But that'd not be user friendly. |
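(For concreteness, a minimal illustration of what type-checking by string could look like; the string-comparison body is only a sketch of the idea being weighed, not the PR's implementation, and the helper name is borrowed from the later diff:)

```python
def _is_torch_array(x):
    # Compare the fully qualified type name instead of importing torch.
    # Brittle in exactly the way noted above: if upstream ever renames the
    # class or module, this silently stops matching.
    t = type(x)
    return f"{t.__module__}.{t.__qualname__}" == "torch.Tensor"
```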
This is what Kyle is working on, but it is 1-2 years off, and I don't think we should wait for it. I am convinced by @timhoffm's analysis and am also +0.5 on string typing now. |
Couldn’t you accomplish option 2 without the performance impact by looking to see if certain modules are already in sys.modules? |
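(A minimal sketch of that idea for the torch case; the larger snippet further down in the thread elaborates on the same sys.modules pattern:)

```python
import sys

def _is_torch_array(x):
    # Only consult torch if the user has already imported it; a tensor
    # cannot exist unless 'torch' is present in sys.modules.
    torch = sys.modules.get('torch')
    return torch is not None and isinstance(x, torch.Tensor)
```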
@timhoffm Thanks for the review and suggestion for testing. I have applied the suggested changes and implemented the first version of testing for this feature. |
Looks good – one comment is that perhaps we should abstract this a bit: maybe have a configurable list of external objects to look for (e.g. external_objects = ['torch.Tensor', 'jax.Array']) and write just a single function that loops through and checks these. Then in the future, if someone wanted to add cupy.ndarray or something it would be just a tiny change, and it could even be done at runtime if we wanted to provide that API.
@jakevdp How'd you suggest abstracting this? Some relevant points:

```python
import cupy

array = cupy.array([1, 2, 3.0])
np_array = array.__array__()  # fails
# np_array = array.get()  # this works
```

Output:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[3], line 1
----> 1 array.__array__()

File cupy/_core/core.pyx:1475, in cupy._core.core._ndarray_base.__array__()

TypeError: Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.
```
Considering both of the above cases, the discussion in #25882, and the discussion in this PR, would it be better to provide two methods (is_type and to_numpy) per supported array type, along these lines?

```python
import sys
import numpy as np
from abc import ABC, abstractmethod


class TypeArray(ABC):
    @staticmethod
    @abstractmethod
    def is_type(x):
        pass

    @staticmethod
    @abstractmethod
    def to_numpy(x):
        pass


class TorchArray(TypeArray):
    @staticmethod
    def is_type(x):
        """Check if 'x' is a PyTorch Tensor."""
        try:
            # we're intentionally not attempting to import torch. If somebody
            # has created a torch array, torch should already be in sys.modules
            return isinstance(x, sys.modules['torch'].Tensor)
        except Exception:  # TypeError, KeyError, AttributeError, maybe others?
            # we're attempting to access attributes on imported modules which
            # may have arbitrary user code, so we deliberately catch all exceptions
            return False

    @staticmethod
    def to_numpy(x):
        """Convert to NumPy array."""
        # preferred over `.numpy(force=True)` to support older PyTorch versions
        return x.detach().cpu().numpy()


class JaxArray(TypeArray):
    @staticmethod
    def is_type(x):
        """Check if 'x' is a JAX array."""
        try:
            # we're intentionally not attempting to import jax. If somebody
            # has created a jax array, jax should already be in sys.modules
            return isinstance(x, sys.modules['jax'].Array)
        except Exception:  # TypeError, KeyError, AttributeError, maybe others?
            # we're attempting to access attributes on imported modules which
            # may have arbitrary user code, so we deliberately catch all exceptions
            return False

    @staticmethod
    def to_numpy(x):
        """Convert to NumPy array."""
        return x.__array__()  # works even if `x` is on GPU


class CupyArray(TypeArray):
    @staticmethod
    def is_type(x):
        """Check if 'x' is a CuPy array."""
        try:
            # we're intentionally not attempting to import cupy. If somebody
            # has created a cupy array, cupy should already be in sys.modules
            return isinstance(x, sys.modules['cupy'].ndarray)
        except Exception:  # TypeError, KeyError, AttributeError, maybe others?
            # we're attempting to access attributes on imported modules which
            # may have arbitrary user code, so we deliberately catch all exceptions
            return False

    @staticmethod
    def to_numpy(x):
        """Convert to NumPy array."""
        return x.get()


external_objects = [TorchArray, JaxArray, CupyArray]


def _unpack_to_numpy(x):
    """Internal helper to extract data from e.g. pandas and xarray objects."""
    if isinstance(x, np.ndarray):
        # If numpy, return directly
        return x
    if hasattr(x, 'to_numpy'):
        # Assume that any to_numpy() method actually returns a numpy array
        return x.to_numpy()
    if hasattr(x, 'values'):
        xtmp = x.values
        # For example a dict has a 'values' attribute, but it is not a property
        # so in this case we do not want to return a function
        if isinstance(xtmp, np.ndarray):
            return xtmp
    for obj in external_objects:
        assert issubclass(obj, TypeArray)
        if obj.is_type(x):
            xtmp = obj.to_numpy(x)
            # In case to_numpy() doesn't return a numpy array in future
            if isinstance(xtmp, np.ndarray):
                return xtmp
    return x
```
|
IMHO further abstraction would be premature. The current implementation is simple and good enough. Paraphrased from https://youtu.be/UANN2Eu6ZnM?feature=shared
This has two major advantages:
1. You don't create abstractions that you don't use.
2. When you build the abstraction, you have three concrete use cases, so it's more likely the abstraction is suitable. |
I didn't mean to suggest any complicated abstraction; I was thinking something simple like this:

```python
import sys
import numpy as np

ARRAYLIKE_OBJECTS = [('jax', 'Array'), ('torch', 'Tensor')]

def maybe_convert_to_array(x):
    for mod, name in ARRAYLIKE_OBJECTS:
        try:
            is_array = isinstance(x, getattr(sys.modules[mod], name))
        except Exception:
            pass
        else:
            if is_array:
                return np.asarray(x)
    return x
```

It reduces duplication of logic and makes it easier to add additional types if/when needed. |
```diff
@@ -2358,6 +2382,12 @@ def _unpack_to_numpy(x):
         # so in this case we do not want to return a function
         if isinstance(xtmp, np.ndarray):
             return xtmp
+    if _is_torch_array(x) or _is_jax_array(x):
+        xtmp = x.__array__()
```
Is there any reason not to do xtmp = np.asarray(x) here?
I think these are equivalent for the handled cases. While pandas claims you shouldn't call __array__ directly, I haven't found any official recommendation for it in numpy (which defines the __array__ API).

For me, either works. __array__ is more explicit, which can be a good thing, but might be too low level. OTOH I don't think np.asarray() will change its implementation in a way that would make the asarray abstraction safer.
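(A quick way to check the "equivalent for the handled cases" claim, assuming PyTorch is installed; not part of the PR:)

```python
import numpy as np
import torch  # assumption: PyTorch available in the environment

t = torch.arange(3)
# Both routes hand back the same CPU ndarray for a plain CPU tensor.
assert np.array_equal(np.asarray(t), t.__array__())
```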
@jakevdp @timhoffm What's your opinion about using x.numpy(force=True) or x.detach().cpu().numpy() for PyTorch? As a user, I'd find this change useful for day-to-day coding since it saves me from manually writing it for every array I want to plot with matplotlib.

```python
if _is_torch_array(x):
    xtmp = x.numpy(force=True)  # or x.detach().cpu().numpy()
if _is_jax_array(x):
    xtmp = x.__array__()
```
Seems reasonable to use the x.numpy method provided by torch. Though I have to say, I don't know whether __array__ would do something different.
@timhoffm To add more context: in JAX, the .__array__() method converts a JAX array to a NumPy array irrespective of whether the JAX array is on CPU, GPU, TPU, or probably other hardware accelerators. OTOH, PyTorch doesn't do this by default due to uncertainty about the performance impact (the full discussion is in this issue). So, when we do x.numpy(force=True), it forcefully converts to NumPy irrespective of the device of the array (it handles a few other cases as well, e.g. if the array has a computation graph for backpropagation then x.detach() needs to be called first). I am not sure if this should be handled in a separate issue or whether we can use x.numpy(force=True) in this PR itself, but I'm sure that PyTorch users would love this change to avoid writing x.detach().cpu().numpy() every time they plot a PyTorch array.

cc @jakevdp.
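(To illustrate the cases force=True papers over, a small example, assuming PyTorch >= 1.13 is installed; not part of the PR:)

```python
import torch  # assumption: PyTorch >= 1.13 available in the environment

x = torch.ones(3, requires_grad=True)
# x.numpy() would raise a RuntimeError: tensors that require grad must be detached first.
explicit = x.detach().cpu().numpy()   # drop the autograd graph, move to CPU
shortcut = x.numpy(force=True)        # same result in one call
assert (explicit == shortcut).all()
```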
My hesitation about this is that memory overflows will get shunted to Matplotlib, and then we will get the bug reports, whereas if people did x.*.numpy() in their own code, they would see what the problem is. Jax arrays can be far larger than memory allows, and Matplotlib blindly unpacking them for naive users seems like a bad idea.
> Though I have to say, I don't know whether __array__ would do something different.

__array__ is the last thing tried by NumPy when a user would call np.asarray, after the buffer protocol and __array_interface__ (as documented for example here). So they're not equivalent, and calling np.asarray is idiomatic.
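(A small illustration of that lookup order: an object exposing only __array__ is still handled by np.asarray; the class here is invented for the example:)

```python
import numpy as np

class OnlyDunderArray:
    # No buffer protocol, no __array_interface__; __array__ is the only hook,
    # and np.asarray falls back to it last.
    def __array__(self, dtype=None, copy=None):
        return np.arange(3, dtype=dtype)

print(np.asarray(OnlyDunderArray()))   # -> [0 1 2]
```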
Yes, that would be marginally better, and can optionally be done. In the interest of not endlessly bikeshedding the PR, I have accepted the current version. After all, this is all internal and can be refactored any time. |
OK, apologies for not paying attention to this properly, but hard-coding certain libraries to have a cut around seems incorrect and brittle to me. What criteria will we have if we get requests to support other libraries? I think the fundamental problem is with where I think it would be a mistake to change The methods where this gets used are all binning methods. The |
This is indeed a workaround. The proper way would be for Otherwise, I think this PR is good enough to be included in 3.9. It achieves the desired speedup and otherwise is completely internal, so we can still change the implementation whenever we like.
Case-by-case. Support them if it's easily possible, don't if it's not. There's little maintenance burden and no API liability. Also, I don't expect that there would be more than a handful of such libraries. |
@timhoffm @jakevdp Getting back to this after a while. To summarize the pending changes:
I think the accepted changes are optimal based on the current circumstances. |
I'm going to merge this to move forward. On one hand, I think it is reasonable to expect users to get their data back to the CPU and into numpy before we plot it, but we have gotten enough bugs and this is a light enough touch. If using

This also sets a reasonable pattern for how we would add support for cupy / the next big library. |
Thank you for your work on this @patel-zeel and congratulations on your first merged Matplotlib PR 🎉 I hope we hear from you again. |
Thank you, @tacaswell, and all the contributors to this PR. This PR has taught me a lot of dev tricks. This wouldn't have been possible without everybody's diverse inputs, @jakevdp's robust ideas, and @timhoffm's pivotal role. |
PR summary

This PR closes #25882 by modifying the _unpack_to_numpy function. The main changes are the following.
- Added an if condition to check whether an object has an __array__ method and whether the new object returned by accessing the __array__ method is a NumPy array.
- Added an if condition to capture NumPy scalars, which were not captured by the ndarray check earlier. This was needed because otherwise NumPy scalar objects get infinitely stuck in the __array__ check, since they get converted to ndarray upon calling the __array__ method on them.

PR checklist