diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index 2fb57a64118f7..b9fecad5840ff 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -147,6 +147,7 @@ Tools ----- - :func:`model_selection.train_test_split` +- :func:`utils.check_consistent_length` Coverage is expected to grow over time. Please follow the dedicated `meta-issue on GitHub <https://github.com/scikit-learn/scikit-learn/issues/22352>`_ to track progress. diff --git a/doc/whats_new/upcoming_changes/array-api/29519.feature.rst b/doc/whats_new/upcoming_changes/array-api/29519.feature.rst new file mode 100644 index 0000000000000..19f800ee45b4b --- /dev/null +++ b/doc/whats_new/upcoming_changes/array-api/29519.feature.rst @@ -0,0 +1,3 @@ +- :func:`sklearn.utils.check_consistent_length` now supports Array API compatible + inputs. + By :user:`Stefanie Senger <StefanieSenger>` diff --git a/sklearn/metrics/cluster/_supervised.py b/sklearn/metrics/cluster/_supervised.py index 7e001cf72c72b..e9ee22056cb5e 100644 --- a/sklearn/metrics/cluster/_supervised.py +++ b/sklearn/metrics/cluster/_supervised.py @@ -1184,7 +1184,7 @@ def fowlkes_mallows_score(labels_true, labels_pred, *, sparse=False): .. versionadded:: 0.18 - The Fowlkes-Mallows index (FMI) is defined as the geometric mean between of + The Fowlkes-Mallows index (FMI) is defined as the geometric mean of the precision and recall:: FMI = TP / sqrt((TP + FP) * (TP + FN)) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index e380a2311355e..b2b4f88fa218f 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -536,10 +536,11 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None): ------- namespace : module Namespace shared by array objects. If any of the `arrays` are not arrays, - the namespace defaults to NumPy. + the namespace defaults to the NumPy namespace. is_array_api_compliant : bool - True if the arrays are containers that implement the Array API spec. 
+ True if the arrays are containers that implement the array API spec (see + https://data-apis.org/array-api/latest/index.html). Always False when array_api_dispatch=False. """ array_api_dispatch = get_config()["array_api_dispatch"] diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 669e40e137e17..d599e3d0784fe 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -34,6 +34,7 @@ check_X_y, deprecated, ) +from sklearn.utils._array_api import yield_namespace_device_dtype_combinations from sklearn.utils._mocking import ( MockDataFrame, _MockEstimatorOnOffPrediction, @@ -41,6 +42,7 @@ from sklearn.utils._testing import ( SkipTest, TempMemmap, + _array_api_for_tests, _convert_container, assert_allclose, assert_allclose_dense_sparse, @@ -1002,6 +1004,8 @@ def test_check_is_fitted_with_attributes(wrap): def test_check_consistent_length(): + """Test that `check_consistent_length` raises on inconsistent lengths and wrong + input types trigger TypeErrors.""" check_consistent_length([1], [2], [3], [4], [5]) check_consistent_length([[1, 2], [[1, 2]]], [1, 2], ["a", "b"]) check_consistent_length([1], (2,), np.array([3]), sp.csr_matrix((1, 2))) @@ -1011,16 +1015,37 @@ def test_check_consistent_length(): check_consistent_length([1, 2], 1) with pytest.raises(TypeError, match=r"got <\w+ 'object'>"): check_consistent_length([1, 2], object()) - with pytest.raises(TypeError): check_consistent_length([1, 2], np.array(1)) - # Despite ensembles having __len__ they must raise TypeError with pytest.raises(TypeError, match="Expected sequence or array-like"): check_consistent_length([1, 2], RandomForestRegressor()) # XXX: We should have a test with a string, but what is correct behaviour? 
+@pytest.mark.parametrize( + "array_namespace, device, _", yield_namespace_device_dtype_combinations() +) +def test_check_consistent_length_array_api(array_namespace, device, _): + """Test that check_consistent_length works with different array types.""" + xp = _array_api_for_tests(array_namespace, device) + + with config_context(array_api_dispatch=True): + check_consistent_length( + xp.asarray([1, 2, 3], device=device), + xp.asarray([[1, 1], [2, 2], [3, 3]], device=device), + [1, 2, 3], + ["a", "b", "c"], + np.asarray(("a", "b", "c"), dtype=object), + sp.csr_array([[0, 1], [1, 0], [0, 0]]), + ) + + with pytest.raises(ValueError, match="inconsistent numbers of samples"): + check_consistent_length( + xp.asarray([1, 2], device=device), xp.asarray([1], device=device) + ) + + def test_check_dataframe_fit_attribute(): # check pandas dataframe with 'fit' column does not raise error # https://github.com/scikit-learn/scikit-learn/issues/8415 diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index ca7c968852975..1332dcdcca69b 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -467,10 +467,8 @@ def check_consistent_length(*arrays): >>> b = [2, 3, 4] >>> check_consistent_length(a, b) """ - lengths = [_num_samples(X) for X in arrays if X is not None] - uniques = np.unique(lengths) - if len(uniques) > 1: + if len(set(lengths)) > 1: raise ValueError( "Found input variables with inconsistent numbers of samples: %r" % [int(l) for l in lengths]