Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 82fd35d

Browse files
committed
enh: convert TensorFlow to numpy in histplots
1 parent 83b07d4 commit 82fd35d

File tree

2 files changed

+58
-6
lines changed

2 files changed

+58
-6
lines changed

lib/matplotlib/cbook.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2381,6 +2381,20 @@ def _is_jax_array(x):
23812381
# may have arbitrary user code, so we deliberately catch all exceptions
23822382
return False
23832383

2384+
def _is_tensorflow_array(x):
2385+
"""Check if 'x' is a TensorFlow Tensor or Variable."""
2386+
try:
2387+
# we're intentionally not attempting to import TensorFlow. If somebody
2388+
# has created a TensorFlow array, TensorFlow should already be in sys.modules
2389+
# we use `is_tensor` to not depend on the class structure of TensorFlow
2390+
# arrays, as `tf.Variables` are not instances of `tf.Tensor`
2391+
# (but convert the same way)
2392+
return isinstance(x, sys.modules['tensorflow'].is_tensor(x))
2393+
except Exception: # TypeError, KeyError, AttributeError, maybe others?
2394+
# we're attempting to access attributes on imported modules which
2395+
# may have arbitrary user code, so we deliberately catch all exceptions
2396+
return False
2397+
23842398

23852399
def _unpack_to_numpy(x):
23862400
"""Internal helper to extract data from e.g. pandas and xarray objects."""
@@ -2396,10 +2410,14 @@ def _unpack_to_numpy(x):
23962410
# so in this case we do not want to return a function
23972411
if isinstance(xtmp, np.ndarray):
23982412
return xtmp
2399-
if _is_torch_array(x) or _is_jax_array(x):
2400-
xtmp = x.__array__()
2401-
2402-
# In case __array__() method does not return a numpy array in future
2413+
if _is_torch_array(x) or _is_jax_array(x) or _is_tensorflow_array(x):
2414+
# using np.asarray() instead of explicitly __array__(), as the latter is
2415+
# only _one_ of many methods, and it's the last resort, see also
2416+
# https://numpy.org/devdocs/user/basics.interoperability.html#using-arbitrary-objects-in-numpy
2417+
# therefore, let arrays do better if they can
2418+
xtmp = np.asarray(x)
2419+
2420+
# In case np.asarray method does not return a numpy array in future
24032421
if isinstance(xtmp, np.ndarray):
24042422
return xtmp
24052423
return x

lib/matplotlib/tests/test_cbook.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,10 @@ def __array__(self):
963963
torch_tensor = torch.Tensor(data)
964964

965965
result = cbook._unpack_to_numpy(torch_tensor)
966-
assert result is torch_tensor.__array__()
966+
# compare results, do not check for identity: the latter would fail
967+
# if not mocked, and the implementation does not guarantee it
968+
# is the same Python object, just the same values.
969+
assert_array_equal(result, data)
967970

968971

969972
def test_unpack_to_numpy_from_jax():
@@ -988,4 +991,35 @@ def __array__(self):
988991
jax_array = jax.Array(data)
989992

990993
result = cbook._unpack_to_numpy(jax_array)
991-
assert result is jax_array.__array__()
994+
# compare results, do not check for identity: the latter would fail
995+
# if not mocked, and the implementation does not guarantee it
996+
# is the same Python object, just the same values.
997+
assert_array_equal(result, data)
998+
999+
def test_unpack_to_numpy_from_tensorflow():
1000+
"""
1001+
Test that tensorflow arrays are converted to NumPy arrays.
1002+
1003+
We don't want to create a dependency on tensorflow in the test suite, so we mock it.
1004+
"""
1005+
class Tensor:
1006+
def __init__(self, data):
1007+
self.data = data
1008+
1009+
def __array__(self):
1010+
return self.data
1011+
1012+
tensorflow = ModuleType('tensorflow')
1013+
tensorflow.is_tensor = lambda x: isinstance(x, Tensor)
1014+
tensorflow.Tensor = Tensor
1015+
1016+
sys.modules['tensorflow'] = tensorflow
1017+
1018+
data = np.arange(10)
1019+
tf_tensor = tensorflow.Tensor(data)
1020+
1021+
result = cbook._unpack_to_numpy(tf_tensor)
1022+
# compare results, do not check for identity: the latter would fail
1023+
# if not mocked, and the implementation does not guarantee it
1024+
# is the same Python object, just the same values.
1025+
assert_array_equal(result, data)

0 commit comments

Comments
 (0)