@@ -963,7 +963,10 @@ def __array__(self):
963
963
torch_tensor = torch .Tensor (data )
964
964
965
965
result = cbook ._unpack_to_numpy (torch_tensor )
966
- assert result is torch_tensor .__array__ ()
966
+ # compare results, do not check for identity: the latter would fail
967
+ # if not mocked, and the implementation does not guarantee it
968
+ # is the same Python object, just the same values.
969
+ assert_array_equal (result , data )
967
970
968
971
969
972
def test_unpack_to_numpy_from_jax ():
@@ -988,4 +991,36 @@ def __array__(self):
988
991
jax_array = jax .Array (data )
989
992
990
993
result = cbook ._unpack_to_numpy (jax_array )
991
- assert result is jax_array .__array__ ()
994
+ # compare results, do not check for identity: the latter would fail
995
+ # if not mocked, and the implementation does not guarantee it
996
+ # is the same Python object, just the same values.
997
+ assert_array_equal (result , data )
998
+
999
+
1000
+ def test_unpack_to_numpy_from_tensorflow ():
1001
+ """
1002
+ Test that tensorflow arrays are converted to NumPy arrays.
1003
+
1004
+ We don't want to create a dependency on tensorflow in the test suite, so we mock it.
1005
+ """
1006
+ class Tensor :
1007
+ def __init__ (self , data ):
1008
+ self .data = data
1009
+
1010
+ def __array__ (self ):
1011
+ return self .data
1012
+
1013
+ tensorflow = ModuleType ('tensorflow' )
1014
+ tensorflow .is_tensor = lambda x : isinstance (x , Tensor )
1015
+ tensorflow .Tensor = Tensor
1016
+
1017
+ sys .modules ['tensorflow' ] = tensorflow
1018
+
1019
+ data = np .arange (10 )
1020
+ tf_tensor = tensorflow .Tensor (data )
1021
+
1022
+ result = cbook ._unpack_to_numpy (tf_tensor )
1023
+ # compare results, do not check for identity: the latter would fail
1024
+ # if not mocked, and the implementation does not guarantee it
1025
+ # is the same Python object, just the same values.
1026
+ assert_array_equal (result , data )
0 commit comments