Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a3ccdfb

Browse files
authored
Merge pull request #28083 from jonas-eschle/je_convert_tf_to_numpy
Convert TensorFlow to numpy for plots
2 parents 2c370be + a6f3635 commit a3ccdfb

File tree

2 files changed

+60
-6
lines changed

2 files changed

+60
-6
lines changed

lib/matplotlib/cbook.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2382,6 +2382,21 @@ def _is_jax_array(x):
23822382
return False
23832383

23842384

2385+
def _is_tensorflow_array(x):
2386+
"""Check if 'x' is a TensorFlow Tensor or Variable."""
2387+
try:
2388+
# we're intentionally not attempting to import TensorFlow. If somebody
2389+
# has created a TensorFlow array, TensorFlow should already be in sys.modules
2390+
# we use `is_tensor` to not depend on the class structure of TensorFlow
2391+
# arrays, as `tf.Variables` are not instances of `tf.Tensor`
2392+
# (they both convert the same way)
2393+
return isinstance(x, sys.modules['tensorflow'].is_tensor(x))
2394+
except Exception: # TypeError, KeyError, AttributeError, maybe others?
2395+
# we're attempting to access attributes on imported modules which
2396+
# may have arbitrary user code, so we deliberately catch all exceptions
2397+
return False
2398+
2399+
23852400
def _unpack_to_numpy(x):
23862401
"""Internal helper to extract data from e.g. pandas and xarray objects."""
23872402
if isinstance(x, np.ndarray):
@@ -2396,10 +2411,14 @@ def _unpack_to_numpy(x):
23962411
# so in this case we do not want to return a function
23972412
if isinstance(xtmp, np.ndarray):
23982413
return xtmp
2399-
if _is_torch_array(x) or _is_jax_array(x):
2400-
xtmp = x.__array__()
2401-
2402-
# In case __array__() method does not return a numpy array in future
2414+
if _is_torch_array(x) or _is_jax_array(x) or _is_tensorflow_array(x):
2415+
# using np.asarray() instead of explicitly __array__(), as the latter is
2416+
# only _one_ of many methods, and it's the last resort, see also
2417+
# https://numpy.org/devdocs/user/basics.interoperability.html#using-arbitrary-objects-in-numpy
2418+
# therefore, let arrays do better if they can
2419+
xtmp = np.asarray(x)
2420+
2421+
# In case np.asarray method does not return a numpy array in future
24032422
if isinstance(xtmp, np.ndarray):
24042423
return xtmp
24052424
return x

lib/matplotlib/tests/test_cbook.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,10 @@ def __array__(self):
963963
torch_tensor = torch.Tensor(data)
964964

965965
result = cbook._unpack_to_numpy(torch_tensor)
966-
assert result is torch_tensor.__array__()
966+
# compare results, do not check for identity: the latter would fail
967+
# if not mocked, and the implementation does not guarantee it
968+
# is the same Python object, just the same values.
969+
assert_array_equal(result, data)
967970

968971

969972
def test_unpack_to_numpy_from_jax():
@@ -988,4 +991,36 @@ def __array__(self):
988991
jax_array = jax.Array(data)
989992

990993
result = cbook._unpack_to_numpy(jax_array)
991-
assert result is jax_array.__array__()
994+
# compare results, do not check for identity: the latter would fail
995+
# if not mocked, and the implementation does not guarantee it
996+
# is the same Python object, just the same values.
997+
assert_array_equal(result, data)
998+
999+
1000+
def test_unpack_to_numpy_from_tensorflow():
1001+
"""
1002+
Test that tensorflow arrays are converted to NumPy arrays.
1003+
1004+
We don't want to create a dependency on tensorflow in the test suite, so we mock it.
1005+
"""
1006+
class Tensor:
1007+
def __init__(self, data):
1008+
self.data = data
1009+
1010+
def __array__(self):
1011+
return self.data
1012+
1013+
tensorflow = ModuleType('tensorflow')
1014+
tensorflow.is_tensor = lambda x: isinstance(x, Tensor)
1015+
tensorflow.Tensor = Tensor
1016+
1017+
sys.modules['tensorflow'] = tensorflow
1018+
1019+
data = np.arange(10)
1020+
tf_tensor = tensorflow.Tensor(data)
1021+
1022+
result = cbook._unpack_to_numpy(tf_tensor)
1023+
# compare results, do not check for identity: the latter would fail
1024+
# if not mocked, and the implementation does not guarantee it
1025+
# is the same Python object, just the same values.
1026+
assert_array_equal(result, data)

0 commit comments

Comments
 (0)