
[MRG+1] Raise warning in scikit-learn/sklearn/linear_model/cd_fast.pyx for cases when the main loop exits without reaching the desired tolerance #11754


Merged · 8 commits, merged on Feb 27, 2019
29 changes: 29 additions & 0 deletions sklearn/linear_model/cd_fast.pyx
@@ -15,6 +15,7 @@ cimport cython
from cpython cimport bool
from cython cimport floating
import warnings
from ..exceptions import ConvergenceWarning

from ..utils._cython_blas cimport (_axpy, _dot, _asum, _ger, _gemv, _nrm2,
                                   _copy, _scal)
@@ -246,6 +247,14 @@ def enet_coordinate_descent(floating[::1] w,
                if gap < tol:
                    # return if we reached desired tolerance
                    break

        else:
            with gil:
                warnings.warn("Objective did not converge."
                              " You might want to increase the number of iterations."
                              " Duality gap: {}, tolerance: {}".format(gap, tol),
                              ConvergenceWarning)

    return w, gap, tol, n_iter + 1
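A note on the pattern used here and repeated in the three solvers below: the descent loop in cd_fast.pyx runs inside a `with nogil:` block (not visible in this hunk), the `else` clause of the `for` loop fires only when all `max_iter` sweeps finish without hitting `break`, and `with gil:` re-acquires the GIL, which any Python-level call such as `warnings.warn` requires. A minimal compilable sketch of that structure, with a hypothetical `demo` function and a toy `gap` update standing in for a real descent sweep:

```cython
import warnings

def demo(int max_iter, double tol):
    cdef int n_iter = 0
    cdef double gap = 1.0
    with nogil:
        for n_iter in range(max_iter):
            gap *= 0.5                 # toy stand-in for one descent sweep
            if gap < tol:
                break                  # converged: the else clause is skipped
        else:
            with gil:                  # Python calls need the GIL back
                warnings.warn("Objective did not converge. "
                              "Duality gap: {}, tolerance: {}".format(gap, tol))
    return gap, n_iter + 1
```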


@@ -456,6 +465,13 @@ def sparse_enet_coordinate_descent(floating [::1] w,
                    # return if we reached desired tolerance
                    break

        else:
            with gil:
                warnings.warn("Objective did not converge."
                              " You might want to increase the number of iterations."
                              " Duality gap: {}, tolerance: {}".format(gap, tol),
                              ConvergenceWarning)

    return w, gap, tol, n_iter + 1


@@ -604,6 +620,13 @@ def enet_coordinate_descent_gram(floating[::1] w,
                    # return if we reached desired tolerance
                    break

        else:
            with gil:
                warnings.warn("Objective did not converge."
                              " You might want to increase the number of iterations."
                              " Duality gap: {}, tolerance: {}".format(gap, tol),
                              ConvergenceWarning)

    return np.asarray(w), gap, tol, n_iter + 1


@@ -794,5 +817,11 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
                if gap < tol:
                    # return if we reached desired tolerance
                    break

        else:
Member commented:

After this warning message is printed, the for-loop goes on if max_iter is not reached, right?
And if max_iter is reached before the condition in 767 fires, then it won't converge but will never warn?
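For readers weighing the same question: Python's for/else runs the `else` suite only when the loop has exhausted every iteration without a `break`, so the warning fires at most once, after the final sweep, and the loop does not continue past it. A quick standalone check:

```python
for n_iter in range(3):
    if n_iter == 1:
        break                  # early exit: the else suite is skipped
else:
    print("not printed")

for n_iter in range(3):
    pass                       # loop runs to completion without break
else:
    print("printed once, after the final iteration")
```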

            with gil:
                warnings.warn("Objective did not converge."
                              " You might want to increase the number of iterations."
                              " Duality gap: {}, tolerance: {}".format(gap, tol),
                              ConvergenceWarning)

    return np.asarray(W), gap, tol, n_iter + 1
13 changes: 0 additions & 13 deletions sklearn/linear_model/coordinate_descent.py
@@ -23,7 +23,6 @@
from ..utils.fixes import _joblib_parallel_args
from ..utils.validation import check_is_fitted
from ..utils.validation import column_or_1d
from ..exceptions import ConvergenceWarning

from . import cd_fast

@@ -481,13 +480,6 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
        coefs[..., i] = coef_
        dual_gaps[i] = dual_gap_
        n_iters.append(n_iter_)
        if dual_gap_ > eps_:
            warnings.warn('Objective did not converge.' +
                          ' You might want' +
                          ' to increase the number of iterations.' +
                          ' Fitting data with very small alpha' +
                          ' may cause precision problems.',
                          ConvergenceWarning)

        if verbose:
            if verbose > 2:
@@ -1812,11 +1804,6 @@ def fit(self, X, y):

        self._set_intercept(X_offset, y_offset, X_scale)

        if self.dual_gap_ > self.eps_:
            warnings.warn('Objective did not converge, you might want'
                          ' to increase the number of iterations',
                          ConvergenceWarning)

        # return self for chaining fit and predict calls
        return self
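With the duality-gap check gone from the Python layer, the warning now originates inside the Cython solvers, and callers observe it exactly as before. A quick end-to-end check mirroring the new test added below (the inputs are chosen only to force non-convergence within two iterations):

```python
import warnings

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import Lasso

X = np.ones((5, 2)) * 1e50     # extreme values: the solver cannot converge
y = np.ones(5)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Lasso(max_iter=2).fit(X, y)
assert any(issubclass(w.category, ConvergenceWarning) for w in caught)
```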

17 changes: 17 additions & 0 deletions sklearn/linear_model/tests/test_coordinate_descent.py
@@ -828,3 +828,20 @@ def test_warm_start_multitask_lasso():
    clf2 = MultiTaskLasso(alpha=0.1, max_iter=10)
    ignore_warnings(clf2.fit)(X, Y)
    assert_array_almost_equal(clf2.coef_, clf.coef_)


@pytest.mark.parametrize('klass, n_classes, kwargs',
                         [(Lasso, 1, dict(precompute=True)),
                          (Lasso, 1, dict(precompute=False)),
                          (MultiTaskLasso, 2, dict())])
def test_enet_coordinate_descent(klass, n_classes, kwargs):
    """Test that a warning is issued if model does not converge"""
    clf = klass(max_iter=2, **kwargs)
    n_samples = 5
    n_features = 2
    X = np.ones((n_samples, n_features)) * 1e50
    y = np.ones((n_samples, n_classes))
    if klass == Lasso:
        y = y.ravel()
    assert_warns(ConvergenceWarning, clf.fit, X, y)
Member commented:

To test this, use tiny data and set max_iter to a very small number; it will make testing faster.

Besides, this is already done in the estimator, e.g.:

https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/linear_model/coordinate_descent.py#L486

Why is that not enough for you? Is it a bug?

Author replied:

For the cases identified in #10813, the estimator raises no warning. Essentially, numerical issues cause the duality gap and the tolerance to both equal zero, so the warning is never raised in the estimator.
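To make that failure mode concrete: the removed estimator-level check only warned when `dual_gap_ > eps_`, and the Cython loop only breaks when `gap < tol`, so when both quantities come out exactly 0.0 neither comparison fires; the for/else warning added here is immune, since it asks only whether the loop exhausted `max_iter` without breaking. A plain-Python illustration with hypothetical values:

```python
gap, tol = 0.0, 0.0   # both collapse to zero for the cases in #10813
print(gap < tol)      # False -> the descent loop never breaks early
print(gap > tol)      # False -> the old estimator-level check never warned
```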

12 changes: 12 additions & 0 deletions sklearn/linear_model/tests/test_sparse_coordinate_descent.py
@@ -8,6 +8,8 @@

from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import ignore_warnings
from sklearn.utils.testing import assert_warns
from sklearn.exceptions import ConvergenceWarning

from sklearn.linear_model.coordinate_descent import (Lasso, ElasticNet,
                                                     LassoCV, ElasticNetCV)
@@ -290,3 +292,13 @@ def test_same_multiple_output_sparse_dense():
    predict_sparse = l_sp.predict(sample_sparse)

    assert_array_almost_equal(predict_sparse, predict_dense)


def test_sparse_enet_coordinate_descent():
    """Test that a warning is issued if model does not converge"""
    clf = Lasso(max_iter=2)
    n_samples = 5
    n_features = 2
    X = sp.csc_matrix((n_samples, n_features)) * 1e50
    y = np.ones(n_samples)
    assert_warns(ConvergenceWarning, clf.fit, X, y)
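One quirk worth flagging: `sp.csc_matrix((n_samples, n_features))` is the shape-tuple constructor and builds an all-zero sparse matrix, so multiplying by 1e50 still leaves every entry zero; unlike the dense test above, it is presumably the degenerate all-zero design matrix, not extreme values, that keeps the solver from reaching tolerance in two iterations. A quick check of that scipy behavior:

```python
import numpy as np
import scipy.sparse as sp

X = sp.csc_matrix((5, 2)) * 1e50
print(X.nnz)                       # 0: no stored entries, and 0 * 1e50 is 0
print(np.all(X.toarray() == 0))    # True
```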