Open up check_array and BaseEstimator._validate_data to overriding xp.asarray with an additional callable parameter asarray_fn #25433
Example given in #25434 |
It was discussed at the triage meeting:
Alternatively, we could use the |
After working on Array API a little more, I propose adding a namespace This proposal should enable the Engine/Plugin API to use the Array API code paths for validation in |
That sounds great. it would answer the usecases I have reported with |
Describe the workflow you want to enable
Some people (including @betatim, @ogrisel, @jjerphan and me) have been devising a plugin system that would open up sklearn estimators to other external implementations, in particular implementations with GPU backends; see #22438. Some of the plugins we're considering can materialize the data in memory with an array library that is compatible with the Array API, namely CuPy and dpctl.tensor.
One thing we've found is that internally those plugins can benefit from directly using BaseEstimator._validate_data and check_array from scikit-learn for the data validation and preparation step.
Describe your proposed solution
To enable this, it would be nice to be able to pass an asarray_fn to check_array and _validate_data that would be called instead of xp.asarray in _asarray_with_order. This would enable the plugin to convert the input data directly to an array that the plugin supports (e.g. cupy or dpctl.tensor) while still reusing the existing validation code in check_array.
The override can be necessary when the asarray function of the array library implements a superset of the Array API that the plugin needs, but that check_array currently does not use because it is not part of the Array API (for instance, the order argument isn't passed to asarray for array libraries other than numpy).
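A rough sketch of the idea, for illustration only. The function `check_array_sketch` and the converter `fortran_asarray` are hypothetical names and do not exist in scikit-learn; NumPy stands in for CuPy or dpctl.tensor so that the snippet runs anywhere. It only shows how a caller-supplied `asarray_fn` could replace the conversion step while the validation that follows stays shared.

```python
import numpy as np


def check_array_sketch(array, *, dtype=None, order=None, asarray_fn=None):
    """Hypothetical stand-in for check_array, not scikit-learn's real code."""
    if asarray_fn is None:
        # Default path: plain NumPy conversion, which understands `order`.
        converted = np.asarray(array, dtype=dtype, order=order)
    else:
        # Plugin path: the plugin decides how to build its own array type
        # and may honor arguments outside the Array API, such as `order`.
        converted = asarray_fn(array, dtype=dtype, order=order)

    # Shared validation that runs whichever converter was used.
    if converted.ndim != 2:
        raise ValueError(f"Expected a 2D array, got {converted.ndim}D instead.")
    if converted.shape[0] == 0:
        raise ValueError("Found array with 0 samples.")
    return converted


def fortran_asarray(array, dtype=None, order=None):
    """Example plugin converter (assumption: the plugin wants Fortran order)."""
    return np.asarray(array, dtype=dtype, order="F" if order is None else order)


# The plugin passes its own converter and still gets the common checks.
X = check_array_sketch([[1.0, 2.0], [3.0, 4.0]], asarray_fn=fortran_asarray)
```

In a real plugin the converter would presumably wrap something like `cupy.asarray` or `dpctl.tensor.asarray`; how `asarray_fn` should interact with the other check_array parameters is left open here.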