[MRG+1] Raise warning in scikit-learn/sklearn/linear_model/cd_fast.pyx for cases when the main loop exits without reaching the desired tolerance #11754

Merged
8 commits merged into scikit-learn:master on Feb 27, 2019

Conversation

@ghost commented Aug 5, 2018

Reference Issues/PRs

Fixes #10813.

What does this implement/fix? Explain your changes.

This pull request adds ConvergenceWarnings to the enet_coordinate_descent* solvers found in scikit-learn/sklearn/linear_model/cd_fast.pyx for cases when the main loop exits without reaching the desired tolerance.

Any other comments?

Tests have been included in both sklearn/linear_model/tests/test_coordinate_descent.py and sklearn/linear_model/tests/test_sparse_coordinate_descent.py.

n_classes = 2
X = np.ones([n_samples, n_features]) * 1e50
y = np.ones([n_samples, n_classes])
assert_warns(ConvergenceWarning, clf.fit, X, y)
Member
To test this, use tiny data and set max_iter to a very small number; it will make testing faster.

Besides, this is already done in the estimator, e.g.:

https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/linear_model/coordinate_descent.py#L486

Why is that not enough for you? Is there a bug?
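(For illustration, a minimal sketch of such a fast test; the estimator choice, the values, and the use of pytest.warns are assumptions, not the merged test:)

import numpy as np
import pytest
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import Lasso

rng = np.random.RandomState(0)
X = rng.randn(5, 3)  # tiny data, as suggested
y = rng.randn(5)
# tol=0 can never be reached, so two iterations cannot converge
clf = Lasso(alpha=0.001, max_iter=2, tol=0)
with pytest.warns(ConvergenceWarning):
    clf.fit(X, y)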

Author

For the cases identified in #10813, the estimator raises no warning. Essentially, numerical issues cause the duality gap and the tolerance to both equal zero; as such, the warning is not raised in the estimator.
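(A minimal sketch of that degenerate case, with assumed values: with inputs on the order of 1e50, intermediate quantities can overflow so that the computed duality gap and the scaled tolerance both end up exactly 0.0:)

gap, tol = 0.0, 0.0

# The solver's early-exit test never fires, so the main loop
# always runs to max_iter:
print(gap < tol)  # False

# And the estimator-level check, being a strict inequality,
# never raises the warning either:
print(gap > tol)  # False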

            else:
                with gil:
                    warnings.warn("Objective did not converge."
                                  " You might want to increase the number of iterations.",
                                  ConvergenceWarning)
Member

Can we include the desired and the achieved tolerance in the warning message?
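(One possible phrasing, purely illustrative rather than the merged message; the example values are assumed:)

import warnings
from sklearn.exceptions import ConvergenceWarning

gap, tol = 1.2e-3, 1e-4  # assumed example values
warnings.warn("Objective did not converge. Duality gap: {}, tolerance: {}."
              " You might want to increase the number of iterations."
              .format(gap, tol), ConvergenceWarning)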

clf = Lasso(precompute=True)
n_samples = 15500
n_features = 500
X = np.ones([n_samples, n_features]) * 1e50
Member

In np.ones(shape), shape is typically a tuple, not a list, e.g. np.ones((n_samples, n_features)).

@@ -302,6 +303,13 @@ def enet_coordinate_descent(np.ndarray[floating, ndim=1] w,
                if gap < tol:
                    # return if we reached desired tolerance
                    break
                else:
Member

Shouldn't this be an else clause of the for loop, not of the if statement?

Please add a test that no warning is raised if the optimisation reaches convergence.
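(An illustration of the for/else construct being suggested, with a hypothetical stand-in for the solver step:)

max_iter, tol = 5, 1e-4

def gap_after(n_iter):
    # hypothetical stand-in for one sweep of coordinate descent
    return 1.0 / (n_iter + 1)

for n_iter in range(max_iter):
    gap = gap_after(n_iter)
    if gap < tol:
        break  # converged: the for/else below is skipped
else:
    # runs only when the loop exhausts max_iter without a break
    print("Objective did not converge. gap=%g, tol=%g" % (gap, tol))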

Member

Since the function returns gap, can't we do this outside of the function?

Author

Issue #10813 suggested raising the warning when the max number of iterations is reached and the desired tolerance has yet to be achieved. It's not obvious to me how to test for that outside of the function because it doesn't return max_iter.

Should this be changed to look for instances where gap and tolerance are both equal to zero - indicating possible numerical issues?

Member

max_iter is passed in, so surely it is available to the caller's test.
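(A sketch of what such a caller-side check could look like; the returned values here are assumed for illustration, not taken from a real run:)

import warnings
from sklearn.exceptions import ConvergenceWarning

max_iter = 1000
# Suppose the solver returned these; enet_coordinate_descent already
# returns the final duality gap, the scaled tolerance, and the iteration
# count, and the caller passed max_iter in:
gap, scaled_tol, n_iter = 3.7e-2, 1e-4, max_iter

# Using >= rather than > would also catch the degenerate
# gap == tol == 0 case discussed above.
if n_iter == max_iter and gap >= scaled_tol:
    warnings.warn("Objective did not converge. Duality gap: %e, "
                  "tolerance: %e." % (gap, scaled_tol), ConvergenceWarning)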

@jnothman (Member) commented Aug 7, 2018

Please rename this PR to describe what it is actually changing. It is not doing what the title says.

@ghost ghost changed the title [MRG] Fix Linear models take unreasonable longer time in certain data size. [MRG] Raise warning in scikit-learn/sklearn/linear_model/cd_fast.pyx for cases when the main loop exits without reaching the desired tolerance Aug 7, 2018
@rth (Member) left a comment

Actually, this check is already implemented in a factorized way for a number of enet_coordinate_descent* functions here, similarly to @jnothman's suggestion in https://github.com/scikit-learn/scikit-learn/pull/11754/files#r208071463.

What is missing is to find the other places where these functions are used and implement a similar check, namely:

sklearn/covariance/graph_lasso_.py
225:                        coefs, _, _, _ = cd_fast.enet_coordinate_descent_gram(

@ghost (Author) commented Nov 14, 2018

The strict inequality in those functions misses the cases raised in the original issue. Essentially, numerical issues cause the duality gap and the tolerance to both equal zero; as such, the warning won't be raised in the estimator. Should the check in those functions be changed?

@agramfort (Member)
Rebased. Now a ConvergenceWarning comes from the Cython code for those who use the Cython functions directly.

@agramfort (Member)
@rth can you have a look?

@GaelVaroquaux (Member) left a comment

LGTM, 👍 for merge

But, a caveat: if this PR ends up cranking up the number of warnings (I've tried to check, but it's really hard to assess), I would push for reverting it.

@GaelVaroquaux GaelVaroquaux requested a review from rth February 26, 2019 17:01
@GaelVaroquaux GaelVaroquaux changed the title [MRG] Raise warning in scikit-learn/sklearn/linear_model/cd_fast.pyx for cases when the main loop exits without reaching the desired tolerance [MRG+1] Raise warning in scikit-learn/sklearn/linear_model/cd_fast.pyx for cases when the main loop exits without reaching the desired tolerance Feb 26, 2019
@rth (Member) left a comment

Thanks!

@rth rth merged commit 3e715fd into scikit-learn:master Feb 27, 2019
@rth (Member) commented Feb 27, 2019

@@ -794,5 +817,11 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
                if gap < tol:
                    # return if we reached desired tolerance
                    break
                else:
Member

After this warning message is printed, the for loop goes on if max_iter is not reached, right?
And if max_iter is reached before the condition in line 767 is hit, then it won't converge but will never warn?
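(A simplified, runnable sketch of the control flow being questioned; the guard and the gap sequence are hypothetical placeholders:)

import warnings

max_iter, tol = 10, 1e-12

def criterion_met(n_iter):
    # hypothetical inner guard (the "condition in 767"):
    # here it only fires on even iterations
    return n_iter % 2 == 0

for n_iter in range(max_iter):
    # ... coordinate updates would happen here ...
    if criterion_met(n_iter):
        gap = 1.0 / (n_iter + 1)  # hypothetical gap, never below tol
        if gap < tol:
            break  # converged
        else:
            # fires on an intermediate iteration, and the loop continues;
            # had the guard never fired again before max_iter, the loop
            # would exit without warning at all
            warnings.warn("gap %g above tol %g at iteration %d"
                          % (gap, tol, n_iter))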

xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
…r cases when the main loop exits without reaching the desired tolerance (scikit-learn#11754)
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
…t.pyx for cases when the main loop exits without reaching the desired tolerance (scikit-learn#11754)"

This reverts commit 97a1b0e.
koenvandevelde pushed a commit to koenvandevelde/scikit-learn that referenced this pull request Jul 12, 2019
…r cases when the main loop exits without reaching the desired tolerance (scikit-learn#11754)
Successfully merging this pull request may close these issues:

Linear models take unreasonable longer time in certain data size. (#10813)