@@ -963,7 +963,10 @@ def __array__(self):
963963 torch_tensor = torch .Tensor (data )
964964
965965 result = cbook ._unpack_to_numpy (torch_tensor )
966- assert result is torch_tensor .__array__ ()
966+ # compare results, do not check for identity: the latter would fail
967+ # if not mocked, and the implementation does not guarantee it
968+ # is the same Python object, just the same values.
969+ assert_array_equal (result , data )
967970
968971
969972def test_unpack_to_numpy_from_jax ():
@@ -988,4 +991,36 @@ def __array__(self):
988991 jax_array = jax .Array (data )
989992
990993 result = cbook ._unpack_to_numpy (jax_array )
991- assert result is jax_array .__array__ ()
994+ # compare results, do not check for identity: the latter would fail
995+ # if not mocked, and the implementation does not guarantee it
996+ # is the same Python object, just the same values.
997+ assert_array_equal (result , data )
998+
999+
1000+ def test_unpack_to_numpy_from_tensorflow ():
1001+ """
1002+ Test that tensorflow arrays are converted to NumPy arrays.
1003+
1004+ We don't want to create a dependency on tensorflow in the test suite, so we mock it.
1005+ """
1006+ class Tensor :
1007+ def __init__ (self , data ):
1008+ self .data = data
1009+
1010+ def __array__ (self ):
1011+ return self .data
1012+
1013+ tensorflow = ModuleType ('tensorflow' )
1014+ tensorflow .is_tensor = lambda x : isinstance (x , Tensor )
1015+ tensorflow .Tensor = Tensor
1016+
1017+ sys .modules ['tensorflow' ] = tensorflow
1018+
1019+ data = np .arange (10 )
1020+ tf_tensor = tensorflow .Tensor (data )
1021+
1022+ result = cbook ._unpack_to_numpy (tf_tensor )
1023+ # compare results, do not check for identity: the latter would fail
1024+ # if not mocked, and the implementation does not guarantee it
1025+ # is the same Python object, just the same values.
1026+ assert_array_equal (result , data )
0 commit comments