diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 73b87d260acba..f048e7f955995 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -24,6 +24,13 @@ random sampling procedures. and has been fixed. :pr:`26416` by :user:`Yang Tao <mchikyt3>`. +- |Fix| Ridge models with `solver='sparse_cg'` may have slightly different + results with scipy>=1.12, because of an underlying change in the scipy solver + (see `scipy#18488 <https://github.com/scipy/scipy/pull/18488>`_ for more + details). + :pr:`26814` by :user:`Loïc Estève <lesteve>`. + + Changes impacting all modules ----------------------------- diff --git a/sklearn/linear_model/_ridge.py b/sklearn/linear_model/_ridge.py index 63c02185fe4a1..0258a379b8852 100644 --- a/sklearn/linear_model/_ridge.py +++ b/sklearn/linear_model/_ridge.py @@ -33,6 +33,7 @@ ) from ..utils._param_validation import Interval, StrOptions, validate_params from ..utils.extmath import row_norms, safe_sparse_dot +from ..utils.fixes import _sparse_linalg_cg from ..utils.sparsefuncs import mean_variance_axis from ..utils.validation import _check_sample_weight, check_is_fitted from ._base import LinearClassifierMixin, LinearModel, _preprocess_data, _rescale_data @@ -105,7 +106,7 @@ def _mv(x): C = sp_linalg.LinearOperator( (n_samples, n_samples), matvec=mv, dtype=X.dtype ) - coef, info = sp_linalg.cg(C, y_column, tol=tol, atol="legacy") + coef, info = _sparse_linalg_cg(C, y_column, rtol=tol) coefs[i] = X1.rmatvec(coef) else: # linear ridge @@ -114,9 +115,7 @@ def _mv(x): C = sp_linalg.LinearOperator( (n_features, n_features), matvec=mv, dtype=X.dtype ) - coefs[i], info = sp_linalg.cg( - C, y_column, maxiter=max_iter, tol=tol, atol="legacy" - ) + coefs[i], info = _sparse_linalg_cg(C, y_column, maxiter=max_iter, rtol=tol) if info < 0: raise ValueError("Failed with error code %d" % info) diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index d33b638358157..ba5ce3c35d07f 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -15,6 +15,7 @@ import numpy as np import
scipy
+import scipy.sparse.linalg
 import scipy.stats
 import threadpoolctl
 
@@ -109,6 +110,24 @@ def _mode(a, axis=0):
     return scipy.stats.mode(a, axis=axis)
 
 
+# TODO: Remove when Scipy 1.12 is the minimum supported version
+if sp_base_version >= parse_version("1.12.0"):
+    _sparse_linalg_cg = scipy.sparse.linalg.cg
+else:
+
+    def _sparse_linalg_cg(A, b, **kwargs):
+        """Call `scipy.sparse.linalg.cg` using the scipy>=1.12 keyword names.
+
+        A caller-supplied `rtol` is forwarded under the name `tol`, and `atol`
+        is set to "legacy" unless the caller passed a value for it. All other
+        keyword arguments are forwarded unchanged.
+        """
+        legacy_kwargs = {"atol": "legacy", **kwargs}
+        if "rtol" in legacy_kwargs:
+            legacy_kwargs["tol"] = legacy_kwargs.pop("rtol")
+        return scipy.sparse.linalg.cg(A, b, **legacy_kwargs)
+
+
 ###############################################################################
 # Backport of Python 3.9's importlib.resources
 # TODO: Remove when Python 3.9 is the minimum supported version