@@ -963,7 +963,10 @@ def __array__(self):
963
963
torch_tensor = torch .Tensor (data )
964
964
965
965
result = cbook ._unpack_to_numpy (torch_tensor )
966
- assert result is torch_tensor .__array__ ()
966
+ # compare results, do not check for identity: the latter would fail
967
+ # if not mocked, and the implementation does not guarantee it
968
+ # is the same Python object, just the same values.
969
+ assert_array_equal (result , data )
967
970
968
971
969
972
def test_unpack_to_numpy_from_jax ():
@@ -988,4 +991,35 @@ def __array__(self):
988
991
jax_array = jax .Array (data )
989
992
990
993
result = cbook ._unpack_to_numpy (jax_array )
991
- assert result is jax_array .__array__ ()
994
+ # compare results, do not check for identity: the latter would fail
995
+ # if not mocked, and the implementation does not guarantee it
996
+ # is the same Python object, just the same values.
997
+ assert_array_equal (result , data )
998
+
999
+ def test_unpack_to_numpy_from_tensorflow ():
1000
+ """
1001
+ Test that tensorflow arrays are converted to NumPy arrays.
1002
+
1003
+ We don't want to create a dependency on tensorflow in the test suite, so we mock it.
1004
+ """
1005
+ class Tensor :
1006
+ def __init__ (self , data ):
1007
+ self .data = data
1008
+
1009
+ def __array__ (self ):
1010
+ return self .data
1011
+
1012
+ tensorflow = ModuleType ('tensorflow' )
1013
+ tensorflow .is_tensor = lambda x : isinstance (x , Tensor )
1014
+ tensorflow .Tensor = Tensor
1015
+
1016
+ sys .modules ['tensorflow' ] = tensorflow
1017
+
1018
+ data = np .arange (10 )
1019
+ tf_tensor = tensorflow .Tensor (data )
1020
+
1021
+ result = cbook ._unpack_to_numpy (tf_tensor )
1022
+ # compare results, do not check for identity: the latter would fail
1023
+ # if not mocked, and the implementation does not guarantee it
1024
+ # is the same Python object, just the same values.
1025
+ assert_array_equal (result , data )
0 commit comments