diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py index a41bfe56744f..d6d48ecc928c 100644 --- a/lib/matplotlib/cbook.py +++ b/lib/matplotlib/cbook.py @@ -2382,6 +2382,21 @@ def _is_jax_array(x): return False +def _is_tensorflow_array(x): + """Check if 'x' is a TensorFlow Tensor or Variable.""" + try: + # we're intentionally not attempting to import TensorFlow. If somebody + # has created a TensorFlow array, TensorFlow should already be in sys.modules + # we use `is_tensor` to not depend on the class structure of TensorFlow + # arrays, as `tf.Variables` are not instances of `tf.Tensor` + # (they both convert the same way) + return isinstance(x, sys.modules['tensorflow'].is_tensor(x)) + except Exception: # TypeError, KeyError, AttributeError, maybe others? + # we're attempting to access attributes on imported modules which + # may have arbitrary user code, so we deliberately catch all exceptions + return False + + def _unpack_to_numpy(x): """Internal helper to extract data from e.g. pandas and xarray objects.""" if isinstance(x, np.ndarray): @@ -2396,10 +2411,14 @@ def _unpack_to_numpy(x): # so in this case we do not want to return a function if isinstance(xtmp, np.ndarray): return xtmp - if _is_torch_array(x) or _is_jax_array(x): - xtmp = x.__array__() - - # In case __array__() method does not return a numpy array in future + if _is_torch_array(x) or _is_jax_array(x) or _is_tensorflow_array(x): + # using np.asarray() instead of explicitly __array__(), as the latter is + # only _one_ of many methods, and it's the last resort, see also + # https://numpy.org/devdocs/user/basics.interoperability.html#using-arbitrary-objects-in-numpy + # therefore, let arrays do better if they can + xtmp = np.asarray(x) + + # In case np.asarray method does not return a numpy array in future if isinstance(xtmp, np.ndarray): return xtmp return x diff --git a/lib/matplotlib/tests/test_cbook.py b/lib/matplotlib/tests/test_cbook.py index 7dff100978b9..5d46c0a75775 100644 --- a/lib/matplotlib/tests/test_cbook.py +++ b/lib/matplotlib/tests/test_cbook.py @@ -963,7 +963,10 @@ def __array__(self): torch_tensor = torch.Tensor(data) result = cbook._unpack_to_numpy(torch_tensor) - assert result is torch_tensor.__array__() + # compare results, do not check for identity: the latter would fail + # if not mocked, and the implementation does not guarantee it + # is the same Python object, just the same values. + assert_array_equal(result, data) def test_unpack_to_numpy_from_jax(): @@ -988,4 +991,36 @@ def __array__(self): jax_array = jax.Array(data) result = cbook._unpack_to_numpy(jax_array) - assert result is jax_array.__array__() + # compare results, do not check for identity: the latter would fail + # if not mocked, and the implementation does not guarantee it + # is the same Python object, just the same values. + assert_array_equal(result, data) + + +def test_unpack_to_numpy_from_tensorflow(): + """ + Test that tensorflow arrays are converted to NumPy arrays. + + We don't want to create a dependency on tensorflow in the test suite, so we mock it. + """ + class Tensor: + def __init__(self, data): + self.data = data + + def __array__(self): + return self.data + + tensorflow = ModuleType('tensorflow') + tensorflow.is_tensor = lambda x: isinstance(x, Tensor) + tensorflow.Tensor = Tensor + + sys.modules['tensorflow'] = tensorflow + + data = np.arange(10) + tf_tensor = tensorflow.Tensor(data) + + result = cbook._unpack_to_numpy(tf_tensor) + # compare results, do not check for identity: the latter would fail + # if not mocked, and the implementation does not guarantee it + # is the same Python object, just the same values. + assert_array_equal(result, data)