Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Convert TensorFlow to numpy for plots #28083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions lib/matplotlib/cbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -2382,6 +2382,21 @@ def _is_jax_array(x):
return False


def _is_tensorflow_array(x):
"""Check if 'x' is a TensorFlow Tensor or Variable."""
try:
# we're intentionally not attempting to import TensorFlow. If somebody
# has created a TensorFlow array, TensorFlow should already be in sys.modules
# we use `is_tensor` to not depend on the class structure of TensorFlow
# arrays, as `tf.Variables` are not instances of `tf.Tensor`
# (they both convert the same way)
return isinstance(x, sys.modules['tensorflow'].is_tensor(x))
except Exception: # TypeError, KeyError, AttributeError, maybe others?
# we're attempting to access attributes on imported modules which
# may have arbitrary user code, so we deliberately catch all exceptions
return False


def _unpack_to_numpy(x):
"""Internal helper to extract data from e.g. pandas and xarray objects."""
if isinstance(x, np.ndarray):
Expand All @@ -2396,10 +2411,14 @@ def _unpack_to_numpy(x):
# so in this case we do not want to return a function
if isinstance(xtmp, np.ndarray):
return xtmp
if _is_torch_array(x) or _is_jax_array(x):
xtmp = x.__array__()

# In case __array__() method does not return a numpy array in future
if _is_torch_array(x) or _is_jax_array(x) or _is_tensorflow_array(x):
# using np.asarray() instead of explicitly __array__(), as the latter is
# only _one_ of many methods, and it's the last resort, see also
# https://numpy.org/devdocs/user/basics.interoperability.html#using-arbitrary-objects-in-numpy
# therefore, let arrays do better if they can
xtmp = np.asarray(x)

# In case np.asarray method does not return a numpy array in future
if isinstance(xtmp, np.ndarray):
return xtmp
return x
Expand Down
39 changes: 37 additions & 2 deletions lib/matplotlib/tests/test_cbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,10 @@ def __array__(self):
torch_tensor = torch.Tensor(data)

result = cbook._unpack_to_numpy(torch_tensor)
assert result is torch_tensor.__array__()
# compare results, do not check for identity: the latter would fail
# if not mocked, and the implementation does not guarantee it
# is the same Python object, just the same values.
assert_array_equal(result, data)


def test_unpack_to_numpy_from_jax():
Expand All @@ -988,4 +991,36 @@ def __array__(self):
jax_array = jax.Array(data)

result = cbook._unpack_to_numpy(jax_array)
assert result is jax_array.__array__()
# compare results, do not check for identity: the latter would fail
# if not mocked, and the implementation does not guarantee it
# is the same Python object, just the same values.
assert_array_equal(result, data)


def test_unpack_to_numpy_from_tensorflow():
"""
Test that tensorflow arrays are converted to NumPy arrays.

We don't want to create a dependency on tensorflow in the test suite, so we mock it.
"""
class Tensor:
def __init__(self, data):
self.data = data

def __array__(self):
return self.data

tensorflow = ModuleType('tensorflow')
tensorflow.is_tensor = lambda x: isinstance(x, Tensor)
tensorflow.Tensor = Tensor

sys.modules['tensorflow'] = tensorflow

data = np.arange(10)
tf_tensor = tensorflow.Tensor(data)

result = cbook._unpack_to_numpy(tf_tensor)
# compare results, do not check for identity: the latter would fail
# if not mocked, and the implementation does not guarantee it
# is the same Python object, just the same values.
assert_array_equal(result, data)
Loading