diff --git a/sklearn/impute/_base.py b/sklearn/impute/_base.py index 1842a0527891a..225ae249b2107 100644 --- a/sklearn/impute/_base.py +++ b/sklearn/impute/_base.py @@ -9,10 +9,10 @@ import numpy as np import numpy.ma as ma from scipy import sparse as sp -from scipy import stats from ..base import BaseEstimator, TransformerMixin from ..utils._param_validation import StrOptions +from ..utils.fixes import _mode from ..utils.sparsefuncs import _get_median from ..utils.validation import check_is_fitted from ..utils.validation import FLOAT_DTYPES @@ -52,7 +52,7 @@ def _most_frequent(array, extra_value, n_repeat): if count == most_frequent_count ) else: - mode = stats.mode(array) + mode = _mode(array) most_frequent_value = mode[0][0] most_frequent_count = mode[1][0] else: diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index 25ee67728e1e7..eebd615b2491c 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -10,7 +10,7 @@ from numbers import Integral import numpy as np -from scipy import stats +from ..utils.fixes import _mode from ..utils.extmath import weighted_mode from ..utils.validation import _is_arraylike, _num_samples @@ -249,7 +249,7 @@ def predict(self, X): y_pred = np.empty((n_queries, n_outputs), dtype=classes_[0].dtype) for k, classes_k in enumerate(classes_): if weights is None: - mode, _ = stats.mode(_y[neigh_ind, k], axis=1) + mode, _ = _mode(_y[neigh_ind, k], axis=1) else: mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index b0074ae7e3a18..cdd63e00cd381 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -163,3 +163,10 @@ def threadpool_info(): threadpool_info.__doc__ = threadpoolctl.threadpool_info.__doc__ + + +# TODO: Remove when SciPy 1.9 is the minimum supported version +def _mode(a, axis=0): + if sp_version >= parse_version("1.9.0"): + return scipy.stats.mode(a, axis=axis, keepdims=True) + return scipy.stats.mode(a, axis=axis) diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index 07a553c8cf09d..14e541bbef2dc 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -6,7 +6,6 @@ import numpy as np from scipy import sparse from scipy import linalg -from scipy import stats from scipy.sparse.linalg import eigsh from scipy.special import expit @@ -19,6 +18,7 @@ from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import skip_if_32bit +from sklearn.utils.fixes import _mode from sklearn.utils.extmath import density, _safe_accumulator_op from sklearn.utils.extmath import randomized_svd, _randomized_eigsh @@ -56,7 +56,7 @@ def test_uniform_weights(): weights = np.ones(x.shape) for axis in (None, 0, 1): - mode, score = stats.mode(x, axis) + mode, score = _mode(x, axis) mode2, score2 = weighted_mode(x, weights, axis=axis) assert_array_equal(mode, mode2)