Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 97c8086

Browse files
ogriseladrinjalali
authored andcommitted
MNT improve the convergence warning message for LogisticRegression (scikit-learn#15665)
1 parent 56032ce commit 97c8086

File tree

3 files changed

+32
-13
lines changed

3 files changed

+32
-13
lines changed

sklearn/linear_model/_logistic.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@
3838
from ..metrics import get_scorer
3939

4040

41+
_LOGISTIC_SOLVER_CONVERGENCE_MSG = (
42+
"Please also refer to the documentation for alternative solver options:\n"
43+
" https://scikit-learn.org/stable/modules/linear_model.html"
44+
"#logistic-regression")
45+
46+
4147
# .. some helper functions for logistic_regression_path ..
4248
def _intercept_dot(w, X, y):
4349
"""Computes y * np.dot(X, w).
@@ -928,7 +934,9 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
928934
args=(X, target, 1. / C, sample_weight),
929935
options={"iprint": iprint, "gtol": tol, "maxiter": max_iter}
930936
)
931-
n_iter_i = _check_optimize_result(solver, opt_res, max_iter)
937+
n_iter_i = _check_optimize_result(
938+
solver, opt_res, max_iter,
939+
extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)
932940
w0, loss = opt_res.x, opt_res.fun
933941
elif solver == 'newton-cg':
934942
args = (X, target, 1. / C, sample_weight)

sklearn/linear_model/tests/test_logistic.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from sklearn.utils._testing import skip_if_no_parallel
3232

3333
from sklearn.exceptions import ConvergenceWarning
34-
from sklearn.exceptions import ChangedBehaviorWarning
3534
from sklearn.linear_model._logistic import (
3635
LogisticRegression,
3736
logistic_regression_path,
@@ -391,13 +390,20 @@ def test_logistic_regression_path_convergence_fail():
391390
y = [1] * 100 + [-1] * 100
392391
Cs = [1e3]
393392

394-
msg = (r"lbfgs failed to converge.+Increase the number of iterations or "
395-
r"scale the data")
396-
397-
with pytest.warns(ConvergenceWarning, match=msg):
393+
# Check that the convergence message points to both a model agnostic
394+
# advice (scaling the data) and to the logistic regression specific
395+
# documentation that includes hints on the solver configuration.
396+
with pytest.warns(ConvergenceWarning) as record:
398397
_logistic_regression_path(
399398
X, y, Cs=Cs, tol=0., max_iter=1, random_state=0, verbose=0)
400399

400+
assert len(record) == 1
401+
warn_msg = record[0].message.args[0]
402+
assert "lbfgs failed to converge" in warn_msg
403+
assert "Increase the number of iterations" in warn_msg
404+
assert "scale the data" in warn_msg
405+
assert "linear_model.html#logistic-regression" in warn_msg
406+
401407

402408
def test_liblinear_dual_random_state():
403409
# random_state is relevant for liblinear solver only if dual=True

sklearn/utils/optimize.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ def _newton_cg(grad_hess, func, grad, x0, args=(), tol=1e-4,
213213
return xk, k
214214

215215

216-
def _check_optimize_result(solver, result, max_iter=None):
216+
def _check_optimize_result(solver, result, max_iter=None,
217+
extra_warning_msg=None):
217218
"""Check the OptimizeResult for successful convergence
218219
219220
Parameters
@@ -233,12 +234,16 @@ def _check_optimize_result(solver, result, max_iter=None):
233234
# handle both scipy and scikit-learn solver names
234235
if solver == "lbfgs":
235236
if result.status != 0:
236-
warnings.warn("{} failed to converge (status={}): {}. "
237-
"Increase the number of iterations or scale the "
238-
"data as shown in https://scikit-learn.org/stable/"
239-
"modules/preprocessing.html"
240-
.format(solver, result.status, result.message),
241-
ConvergenceWarning, stacklevel=2)
237+
warning_msg = (
238+
"{} failed to converge (status={}):\n{}.\n\n"
239+
"Increase the number of iterations (max_iter) "
240+
"or scale the data as shown in:\n"
241+
" https://scikit-learn.org/stable/modules/"
242+
"preprocessing.html."
243+
).format(solver, result.status, result.message.decode("latin1"))
244+
if extra_warning_msg is not None:
245+
warning_msg += "\n" + extra_warning_msg
246+
warnings.warn(warning_msg, ConvergenceWarning, stacklevel=2)
242247
if max_iter is not None:
243248
# In scipy <= 1.0.0, nit may exceed maxiter for lbfgs.
244249
# See https://github.com/scipy/scipy/issues/7854

0 commit comments

Comments
 (0)