From 82fd35de3692f77d7d4679b6a2f806f209102ccd Mon Sep 17 00:00:00 2001
From: Jonas Eschle
Date: Mon, 15 Apr 2024 13:15:41 -0400
Subject: [PATCH 1/3] enh: convert TensorFlow to numpy in histplots

---
 lib/matplotlib/cbook.py            | 26 ++++++++++++++++----
 lib/matplotlib/tests/test_cbook.py | 38 ++++++++++++++++++++++++++++--
 2 files changed, 58 insertions(+), 6 deletions(-)

diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py
index a41bfe56744f..117f8dac5da1 100644
--- a/lib/matplotlib/cbook.py
+++ b/lib/matplotlib/cbook.py
@@ -2381,6 +2381,20 @@ def _is_jax_array(x):
         # may have arbitrary user code, so we deliberately catch all exceptions
         return False
 
+def _is_tensorflow_array(x):
+    """Check if 'x' is a TensorFlow Tensor or Variable."""
+    try:
+        # we're intentionally not attempting to import TensorFlow. If somebody
+        # has created a TensorFlow array, TensorFlow should already be in sys.modules
+        # we use `is_tensor` to not depend on the class structure of TensorFlow
+        # arrays, as `tf.Variables` are not instances of `tf.Tensor`
+        # (but convert the same way)
+        return sys.modules['tensorflow'].is_tensor(x)
+    except Exception:  # TypeError, KeyError, AttributeError, maybe others?
+        # we're attempting to access attributes on imported modules which
+        # may have arbitrary user code, so we deliberately catch all exceptions
+        return False
+
 
 def _unpack_to_numpy(x):
     """Internal helper to extract data from e.g. pandas and xarray objects."""
@@ -2396,10 +2410,14 @@ def _unpack_to_numpy(x):
         # so in this case we do not want to return a function
         if isinstance(xtmp, np.ndarray):
             return xtmp
-    if _is_torch_array(x) or _is_jax_array(x):
-        xtmp = x.__array__()
-
-        # In case __array__() method does not return a numpy array in future
+    if _is_torch_array(x) or _is_jax_array(x) or _is_tensorflow_array(x):
+        # using np.asarray() instead of explicitly __array__(), as the latter is
+        # only _one_ of many methods, and it's the last resort, see also
+        # https://numpy.org/devdocs/user/basics.interoperability.html#using-arbitrary-objects-in-numpy
+        # therefore, let arrays do better if they can
+        xtmp = np.asarray(x)
+
+        # In case np.asarray method does not return a numpy array in future
         if isinstance(xtmp, np.ndarray):
             return xtmp
     return x
diff --git a/lib/matplotlib/tests/test_cbook.py b/lib/matplotlib/tests/test_cbook.py
index 7dff100978b9..43761203d800 100644
--- a/lib/matplotlib/tests/test_cbook.py
+++ b/lib/matplotlib/tests/test_cbook.py
@@ -963,7 +963,10 @@ def __array__(self):
     torch_tensor = torch.Tensor(data)
 
     result = cbook._unpack_to_numpy(torch_tensor)
-    assert result is torch_tensor.__array__()
+    # compare results, do not check for identity: the latter would fail
+    # if not mocked, and the implementation does not guarantee it
+    # is the same Python object, just the same values.
+    assert_array_equal(result, data)
 
 
 def test_unpack_to_numpy_from_jax():
@@ -988,4 +991,35 @@ def __array__(self):
     jax_array = jax.Array(data)
 
     result = cbook._unpack_to_numpy(jax_array)
-    assert result is jax_array.__array__()
+    # compare results, do not check for identity: the latter would fail
+    # if not mocked, and the implementation does not guarantee it
+    # is the same Python object, just the same values.
+    assert_array_equal(result, data)
+
+def test_unpack_to_numpy_from_tensorflow():
+    """
+    Test that tensorflow arrays are converted to NumPy arrays.
+
+    We don't want to create a dependency on tensorflow in the test suite, so we mock it.
+    """
+    class Tensor:
+        def __init__(self, data):
+            self.data = data
+
+        def __array__(self):
+            return self.data
+
+    tensorflow = ModuleType('tensorflow')
+    tensorflow.is_tensor = lambda x: isinstance(x, Tensor)
+    tensorflow.Tensor = Tensor
+
+    sys.modules['tensorflow'] = tensorflow
+
+    data = np.arange(10)
+    tf_tensor = tensorflow.Tensor(data)
+
+    result = cbook._unpack_to_numpy(tf_tensor)
+    # compare results, do not check for identity: the latter would fail
+    # if not mocked, and the implementation does not guarantee it
+    # is the same Python object, just the same values.
+    assert_array_equal(result, data)

From 1ee7bc01f16882e464d0e2b616267811723052c7 Mon Sep 17 00:00:00 2001
From: Jonas Eschle
Date: Mon, 15 Apr 2024 13:18:26 -0400
Subject: [PATCH 2/3] enh: convert TensorFlow to numpy in histplots

---
 lib/matplotlib/cbook.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py
index 117f8dac5da1..bd84b7cf4555 100644
--- a/lib/matplotlib/cbook.py
+++ b/lib/matplotlib/cbook.py
@@ -2388,7 +2388,7 @@ def _is_tensorflow_array(x):
         # has created a TensorFlow array, TensorFlow should already be in sys.modules
         # we use `is_tensor` to not depend on the class structure of TensorFlow
         # arrays, as `tf.Variables` are not instances of `tf.Tensor`
-        # (but convert the same way)
+        # (they both convert the same way)
         return sys.modules['tensorflow'].is_tensor(x)
     except Exception:  # TypeError, KeyError, AttributeError, maybe others?
         # we're attempting to access attributes on imported modules which

From a6f3635ff081424d1ca529c21b31261802f12a57 Mon Sep 17 00:00:00 2001
From: Jonas Eschle
Date: Mon, 15 Apr 2024 13:26:18 -0400
Subject: [PATCH 3/3] chore: fix style

---
 lib/matplotlib/cbook.py            | 1 +
 lib/matplotlib/tests/test_cbook.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py
index bd84b7cf4555..d6d48ecc928c 100644
--- a/lib/matplotlib/cbook.py
+++ b/lib/matplotlib/cbook.py
@@ -2381,6 +2381,7 @@ def _is_jax_array(x):
         # may have arbitrary user code, so we deliberately catch all exceptions
         return False
 
+
 def _is_tensorflow_array(x):
     """Check if 'x' is a TensorFlow Tensor or Variable."""
     try:
diff --git a/lib/matplotlib/tests/test_cbook.py b/lib/matplotlib/tests/test_cbook.py
index 43761203d800..5d46c0a75775 100644
--- a/lib/matplotlib/tests/test_cbook.py
+++ b/lib/matplotlib/tests/test_cbook.py
@@ -996,6 +996,7 @@ def __array__(self):
     # is the same Python object, just the same values.
     assert_array_equal(result, data)
 
+
 def test_unpack_to_numpy_from_tensorflow():
     """
     Test that tensorflow arrays are converted to NumPy arrays.