@@ -2349,6 +2349,30 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
23492349 return cls .__new__ (cls )
23502350
23512351
2352+ def _is_torch_array (x ):
2353+ """Check if 'x' is a PyTorch Tensor."""
2354+ try :
2355+ # we're intentionally not attempting to import torch. If somebody
2356+ # has created a torch array, torch should already be in sys.modules
2357+ return isinstance (x , sys .modules ['torch' ].Tensor )
2358+ except Exception : # TypeError, KeyError, AttributeError, maybe others?
2359+ # we're attempting to access attributes on imported modules which
2360+ # may have arbitrary user code, so we deliberately catch all exceptions
2361+ return False
2362+
2363+
2364+ def _is_jax_array (x ):
2365+ """Check if 'x' is a JAX Array."""
2366+ try :
2367+ # we're intentionally not attempting to import jax. If somebody
2368+ # has created a jax array, jax should already be in sys.modules
2369+ return isinstance (x , sys .modules ['jax' ].Array )
2370+ except Exception : # TypeError, KeyError, AttributeError, maybe others?
2371+ # we're attempting to access attributes on imported modules which
2372+ # may have arbitrary user code, so we deliberately catch all exceptions
2373+ return False
2374+
2375+
23522376def _unpack_to_numpy (x ):
23532377 """Internal helper to extract data from e.g. pandas and xarray objects."""
23542378 if isinstance (x , np .ndarray ):
@@ -2363,6 +2387,12 @@ def _unpack_to_numpy(x):
23632387 # so in this case we do not want to return a function
23642388 if isinstance (xtmp , np .ndarray ):
23652389 return xtmp
2390+ if _is_torch_array (x ) or _is_jax_array (x ):
2391+ xtmp = x .__array__ ()
2392+
2393+ # In case __array__() method does not return a numpy array in future
2394+ if isinstance (xtmp , np .ndarray ):
2395+ return xtmp
23662396 return x
23672397
23682398
0 commit comments