Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ Metrics
Tools
-----

- :func:`model_selection.cross_val_predict`
- :func:`model_selection.train_test_split`
- :func:`utils.check_consistent_length`

Expand Down
2 changes: 2 additions & 0 deletions doc/whats_new/upcoming_changes/array-api/32270.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- :func:`sklearn.model_selection.cross_val_predict` now supports array API compatible inputs.
By :user:`Omar Salman <OmarManzoor>`
4 changes: 3 additions & 1 deletion sklearn/model_selection/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1258,7 +1258,9 @@ def cross_val_predict(
concat_pred.append(label_preds)
predictions = concat_pred
else:
predictions = np.concatenate(predictions)
xp, _ = get_namespace(X)
inv_test_indices = xp.asarray(inv_test_indices, device=device(X))
predictions = xp.concat(predictions)

if isinstance(predictions, list):
return [p[inv_test_indices] for p in predictions]
Expand Down
51 changes: 50 additions & 1 deletion sklearn/model_selection/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from scipy.sparse import issparse

from sklearn import config_context
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.base import BaseEstimator, ClassifierMixin, clone, is_classifier
from sklearn.cluster import KMeans
from sklearn.datasets import (
load_diabetes,
Expand All @@ -22,6 +22,7 @@
make_multilabel_classification,
make_regression,
)
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import FitFailedWarning, UnsetMetadataPassedError
from sklearn.impute import SimpleImputer
Expand Down Expand Up @@ -81,8 +82,15 @@
check_recorded_metadata,
)
from sklearn.utils import shuffle
from sklearn.utils._array_api import (
_atol_for_type,
_convert_to_numpy,
_get_namespace_device_dtype_ids,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._mocking import CheckingClassifier, MockDataFrame
from sklearn.utils._testing import (
_array_api_for_tests,
assert_allclose,
assert_almost_equal,
assert_array_almost_equal,
Expand Down Expand Up @@ -2725,3 +2733,44 @@ def test_learning_curve_exploit_incremental_learning_routing():

# End of metadata routing tests
# =============================


@pytest.mark.parametrize(
"estimator",
[Ridge(), LinearDiscriminantAnalysis()],
ids=["Ridge", "LinearDiscriminantAnalysis"],
)
@pytest.mark.parametrize("cv", [None, 3, 5])
@pytest.mark.parametrize(
"namespace, device_, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
def test_cross_val_predict_array_api_compliance(
estimator, cv, namespace, device_, dtype_name
):
"""Test that `cross_val_predict` functions correctly with the array API
with both a classifier and a regressor."""

xp = _array_api_for_tests(namespace, device_)
if is_classifier(estimator):
X, y = make_classification(
n_samples=1000, n_features=5, n_classes=3, n_informative=3, random_state=42
)
else:
X, y = make_regression(
n_samples=1000, n_features=5, n_informative=3, random_state=42
)

X_np = X.astype(dtype_name)
y_np = y.astype(dtype_name)
X_xp = xp.asarray(X_np, device=device_)
y_xp = xp.asarray(y_np, device=device_)

with config_context(array_api_dispatch=True):
pred_xp = cross_val_predict(estimator, X_xp, y_xp, cv=cv)

pred_np = cross_val_predict(estimator, X_np, y_np, cv=cv)
assert_allclose(
_convert_to_numpy(pred_xp, xp), pred_np, atol=_atol_for_type(dtype_name)
)
7 changes: 6 additions & 1 deletion sklearn/utils/_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
import numpy as np
from scipy.sparse import issparse

from sklearn.utils._array_api import _is_numpy_namespace, get_namespace
from sklearn.utils._array_api import (
_is_numpy_namespace,
ensure_common_namespace_device,
get_namespace,
)
from sklearn.utils._param_validation import Interval, validate_params
from sklearn.utils.extmath import _approximate_mode
from sklearn.utils.fixes import PYARROW_VERSION_BELOW_17
Expand All @@ -31,6 +35,7 @@ def _array_indexing(array, key, key_dtype, axis):
"""Index an array or scipy.sparse consistently across NumPy version."""
xp, is_array_api = get_namespace(array)
if is_array_api:
key = ensure_common_namespace_device(array, key)[0]
return xp.take(array, key, axis=axis)
if issparse(array) and key_dtype == "bool":
key = np.asarray(key)
Expand Down
6 changes: 5 additions & 1 deletion sklearn/utils/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,11 @@ def _raise_or_return():
if xp.isdtype(y.dtype, "real floating"):
# [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
data = y.data if issparse(y) else y
if xp.any(data != xp.astype(data, int)):
integral_data = xp.astype(data, xp.int64)
# conversion back to the original float dtype of y is required to
# satisfy array-api-strict which does not allow a comparison between
# arrays having different dtypes.
if xp.any(data != xp.astype(integral_data, y.dtype)):
_assert_all_finite(data, input_name=input_name)
return "continuous" + suffix

Expand Down
Loading