MNT remove take fn in array_api wrapper #27939

Merged

fcharras
Contributor

What does this implement/fix? Explain your changes.

Just a bit of cleanup: this wrapper method is no longer needed now that take has been adopted by the Array API standard.

Any other comments?

The only place where xp.take is used is in sklearn.utils.extmath.svd_flip.
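
To illustrate the kind of call involved, here is a minimal sketch of how a sign-flipping helper in the spirit of svd_flip might use xp.take. It is not the scikit-learn implementation: flip_signs is a hypothetical name, NumPy stands in for the Array API namespace xp, and the u-based sign convention is an assumption.

```python
import numpy as np


def flip_signs(u, v, xp=np):
    """Flip signs so the largest-magnitude entry of each column of u is positive.

    Hypothetical helper, not scikit-learn's svd_flip. `xp` is any namespace
    that provides the Array API `take(x, indices, axis=...)` function.
    """
    n_rows, n_cols = u.shape
    # Row index of the largest |entry| in each column of u.
    max_abs_rows = xp.argmax(xp.abs(u), axis=0)
    # Flatten u column by column, then compute each maximum's flat position.
    flat = xp.reshape(u.T, (-1,))
    indices = max_abs_rows + xp.arange(n_cols) * n_rows
    signs = xp.sign(xp.take(flat, indices, axis=0))
    # Scaling column j of u and row j of v by the same sign leaves u @ v unchanged.
    return u * signs[None, :], v * signs[:, None]


u = np.array([[1.0, -4.0], [-3.0, 2.0]])
v = np.array([[1.0, 2.0], [3.0, 4.0]])
u_flipped, v_flipped = flip_signs(u, v)
```

Because take is called through xp rather than through a scikit-learn wrapper, the same function should work with any namespace that implements a recent enough version of the standard.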

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: ede5ec5. Link to the linter CI: here

Member

@ogrisel ogrisel left a comment

Since the array api support is still experimental, I assume that we can expect our users to always use the latest version of the array-api-compat soft dependency. Therefore I assume that this kind of change is fine as long as all existing tests pass.

Could you please confirm that it works with pytorch and cupy on a CUDA host?

@ogrisel
Member

ogrisel commented Dec 11, 2023

/cc @thomasjpfan

@fcharras
Contributor Author

isdtype can probably be removed and also replaced by xp.isdtype safely by now.

@ogrisel
Member

ogrisel commented Dec 11, 2023

isdtype can probably be removed and also replaced by xp.isdtype safely by now.

Let's open a dedicated PR so that those two simplifications can be merged independently, in case one breaks tests but not the other.

@fcharras
Contributor Author

fcharras commented Dec 11, 2023

Since the array api support is still experimental, I assume that we can expect our users to always use the latest version of the array-api-compat soft dependency. Therefore I assume that this kind of change is fine as long as all existing tests pass.

That's also my thinking.

Could you please confirm that it works with pytorch and cupy on a CUDA host?

Tests OK. The relevant tests are the PCA tests, since that is where svd_flip is used with Array API inputs (see the earlier PR that added Array API compliance to PCA):

With cupy

pytest -v sklearn/decomposition/tests/test_pca.py -k cupy

=============================================== test session starts ===============================================
platform linux -- Python 3.10.12, pytest-7.4.3, pluggy-1.3.0 -- /data/parietal/store3/work/fcharras/mambaforge/envs/cuml_env/bin/python3.10
cachedir: .pytest_cache
rootdir: /data/parietal/store3/work/fcharras/scikit-learn
configfile: setup.cfg
collected 309 items / 293 deselected / 16 selected                                                                

sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-cupy-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-cupy.array_api-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_get_precision-cupy-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_get_precision-cupy.array_api-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_input_and_values-cupy-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_input_and_values-cupy.array_api-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_get_precision-cupy-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_get_precision-cupy.array_api-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_input_and_values-cupy-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_input_and_values-cupy.array_api-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_get_precision-cupy-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_get_precision-cupy.array_api-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_input-cupy-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_input-cupy.array_api-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_get_precision-cupy-None-None] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_get_precision-cupy.array_api-None-None] PASSED

================================== 16 passed, 293 deselected, 1 warning in 2.93s ==================================
With pytorch

pytest -v sklearn/decomposition/tests/test_pca.py -k torch

=========================================================================================================== test session starts ============================================================================================================
platform linux -- Python 3.11.6, pytest-7.4.3, pluggy-1.3.0 -- /data/parietal/store3/work/fcharras/mambaforge/envs/pytorch_env_3/bin/python3.11
cachedir: .pytest_cache
rootdir: /data/parietal/store3/work/fcharras/scikit-learn
configfile: setup.cfg
collected 309 items / 269 deselected / 40 selected                                                                                                                                                                                         

sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-torch-cpu-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-torch-cpu-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-torch-cuda-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-torch-cuda-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_get_precision-torch-cpu-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_get_precision-torch-cpu-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_get_precision-torch-cuda-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_get_precision-torch-cuda-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_get_precision-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_input_and_values-torch-cpu-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_input_and_values-torch-cpu-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_input_and_values-torch-cuda-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_input_and_values-torch-cuda-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_input_and_values-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_get_precision-torch-cpu-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_get_precision-torch-cpu-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_get_precision-torch-cuda-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_get_precision-torch-cuda-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_get_precision-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_input_and_values-torch-cpu-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_input_and_values-torch-cpu-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_input_and_values-torch-cuda-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_input_and_values-torch-cuda-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_input_and_values-torch-mps-float32] SKIPPED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_get_precision-torch-cpu-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_get_precision-torch-cpu-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_get_precision-torch-cuda-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_get_precision-torch-cuda-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_get_precision-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_input-torch-cpu-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_input-torch-cpu-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_input-torch-cuda-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_input-torch-cuda-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_input-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_get_precision-torch-cpu-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_get_precision-torch-cpu-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_get_precision-torch-cuda-float64] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_get_precision-torch-cuda-float32] PASSED
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_get_precision-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)

============================================================================================== 32 passed, 8 skipped, 269 deselected in 20.88s ==============================================================================================

Let's open a dedicated PR so that those two simplifications can be merged independently, in case one breaks tests but not the other.

ok

@fcharras
Contributor Author

isdtype can probably be removed

It cannot: it must at least be implemented for NumPy (np.isdtype does not exist).
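
For illustration, a minimal sketch of what such a NumPy-side fallback could look like, assuming a NumPy version without np.isdtype as described above. The name isdtype_fallback and the kind-to-type mapping are hypothetical and are not taken from scikit-learn; the kind strings follow the ones defined by the Array API standard.

```python
import numpy as np

# Map Array API dtype "kinds" to NumPy abstract scalar types (assumed mapping).
_KIND_TO_NUMPY = {
    "bool": np.bool_,
    "signed integer": np.signedinteger,
    "unsigned integer": np.unsignedinteger,
    "integral": np.integer,
    "real floating": np.floating,
    "complex floating": np.complexfloating,
    "numeric": np.number,
}


def isdtype_fallback(dtype, kind):
    """Hypothetical stand-in for xp.isdtype when the namespace is NumPy."""
    if isinstance(kind, tuple):
        # A tuple of kinds matches if any single kind matches.
        return any(isdtype_fallback(dtype, k) for k in kind)
    if isinstance(kind, str):
        return bool(np.issubdtype(dtype, _KIND_TO_NUMPY[kind]))
    # Otherwise `kind` is a concrete dtype: compare for equality.
    return np.dtype(dtype) == np.dtype(kind)
```

A call such as isdtype_fallback(X.dtype, ("real floating", "complex floating")) would then behave the same way whether or not the installed NumPy exposes its own isdtype.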

Member

@ogrisel ogrisel left a comment

Thanks for posting the test results. Still +1 for merging this PR in its current state.

Member

@thomasjpfan thomasjpfan left a comment

LGTM

@thomasjpfan thomasjpfan merged commit 8e10cd7 into scikit-learn:main Dec 11, 2023
@betatim
Member

betatim commented Dec 12, 2023

Since the array api support is still experimental, I assume that we can expect our users to always use the latest version of the array-api-compat soft dependency. Therefore I assume that this kind of change is fine as long as all existing tests pass.

A late comment: the take function we removed here is in a private part of scikit-learn (sklearn.utils._array_api). Doesn't this mean it isn't covered by the usual promise of backwards compatibility and deprecation cycles? If so, it is always OK to remove, rename, or otherwise change things that live in this module.

Or was your comment more about the minimum version of array_api_compat we require?

@fcharras
Contributor Author

The Array API standard is versioned, and take is a more recent addition to it. Having take defined inside scikit-learn enables support for earlier versions of the standard, and removing it drops support for those older versions. But that's fine: since our Array API support is experimental, I think we can commit to supporting only the latest version of the standard for now. Once we are out of experimental status, we should have a policy regarding which versions of the Array API we support.
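
As a rough sketch of the trade-off described here, a compatibility shim of the following kind is what keeps older namespaces working. This is not the code that was removed: take_compat is a hypothetical name, and the fallback assumes the array library supports NumPy-style integer-array indexing, which older versions of the standard do not guarantee.

```python
import types

import numpy as np


def take_compat(xp, x, indices, axis=0):
    """Use xp.take when the namespace provides it, else fall back to indexing."""
    if hasattr(xp, "take"):
        return xp.take(x, indices, axis=axis)
    # Fallback: integer-array indexing along `axis`. This relies on
    # NumPy-style indexing, which the array library may or may not support.
    index = (slice(None),) * axis + (indices,)
    return x[index]


x = np.arange(12).reshape(3, 4)
idx = np.array([2, 0])
old_namespace = types.SimpleNamespace()  # stand-in for a namespace without take
```

Dropping a shim like this simplifies the wrapper, at the cost of requiring a namespace that already implements take.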

glemaitre pushed a commit to glemaitre/scikit-learn that referenced this pull request Feb 10, 2024