
[MRG+1] add Convergence warning in LabelPropagation #5893


Merged

Conversation

musically-ut
Contributor

Otherwise, it remains unclear whether convergence was reached or whether the algorithm simply ran out of iterations. Currently, all the test cases trigger this warning; that is what prompted the investigation which led to #5774.

@TomDLT TomDLT changed the title EHN: Show a Convergence warning if the max_iters were performed. [MRG] add Convergence warning in LabelPropagation Nov 26, 2015
@TomDLT
Member

TomDLT commented Nov 26, 2015

LGTM

Can you adapt the tests in order not to raise any convergence warning (or to silence them with ignore_warnings if necessary)?

@musically-ut
Contributor Author

Will do.

@musically-ut
Contributor Author

I can add with ignore_warnings(): in the tests, but it will still trigger the warnings in the doctests. I think I'll work on a more comprehensive solution which addresses both the bug in the algorithm (#5774) and this convergence issue together.
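
For context, a minimal sketch (not the PR's code; the toy data below is made up) of silencing the warning in a test with the ignore_warnings helper from sklearn.utils.testing:

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.semi_supervised import LabelPropagation
from sklearn.utils.testing import ignore_warnings

rng = np.random.RandomState(0)
X = rng.rand(10, 2)
y = np.array([0, 1] + [-1] * 8)  # -1 marks unlabelled samples

with ignore_warnings(category=ConvergenceWarning):
    # max_iter=1 is almost certain not to converge; the warning is swallowed
    LabelPropagation(max_iter=1).fit(X, y)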

@musically-ut
Contributor Author

I will rebase this after #9239 is merged.

@musically-ut musically-ut force-pushed the feat-warn-label-prop-convergence branch from 6005320 to 9260699 Compare July 4, 2017 22:29
@musically-ut
Contributor Author

I've rebased this against the current master and have silenced the warnings in the doctests.

@jnothman jnothman left a comment
Member

Please add a test

@@ -32,7 +32,7 @@
 --------
 >>> from sklearn import datasets
 >>> from sklearn.semi_supervised import LabelPropagation
->>> label_prop_model = LabelPropagation()
+>>> label_prop_model = LabelPropagation(max_iter=1000)
Member

Is the default max_iter something we should be changing while we are making other backwards-incompatible changes? Or is the current default reasonable?

Contributor Author

To be honest, I don't know. The ConvergenceWarning is partly there to allow users to adjust max_iter depending on their problem.
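
For illustration only (made-up data, not part of the PR), this is roughly how a user could act on the warning by retrying with a larger max_iter:

import warnings
import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.semi_supervised import LabelPropagation

rng = np.random.RandomState(0)
X = rng.rand(50, 2)
y = -np.ones(50, dtype=int)   # -1 marks unlabelled samples
y[:3] = 0                     # a few labelled points of class 0
y[3:6] = 1                    # a few labelled points of class 1

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    LabelPropagation(max_iter=5).fit(X, y)

if any(issubclass(w.category, ConvergenceWarning) for w in caught):
    # the budget was too small for this problem; retry with more iterations
    LabelPropagation(max_iter=2000).fit(X, y)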

@musically-ut
Contributor Author

Added a test.

@jnothman
Member

jnothman commented Jul 5, 2017 via email

@jnothman
Member

jnothman commented Jul 5, 2017 via email

@jnothman
Member

jnothman commented Jul 5, 2017

@musically-ut
Contributor Author

Hmm. I'll look into it later today/tomorrow.

@jnothman jnothman left a comment
Member

Can you please also use assert_no_warnings in the convergence case?

@musically-ut
Contributor Author

LabelPropagation and LabelSpreading do seem to have very different convergence behavior. I've changed the default number of iterations for them (1000 for LabelPropagation and 30 for LabelSpreading). I've also limited the number of unlabelled entries in the doctests. I am tempted to set the seed in the doctests to make sure that we don't run into accidental failures. What do you think?

I've added assert_no_warnings to one of the tests where convergence was being tested. Should I add it to others as well?
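
A hedged sketch of what such an assert_no_warnings check might look like (mdl, X and y here stand in for whatever the real test constructs):

import numpy as np
from sklearn.semi_supervised import LabelSpreading
from sklearn.utils.testing import assert_no_warnings

rng = np.random.RandomState(0)
X = rng.rand(20, 2)
y = np.array([0, 1] + [-1] * 18)    # -1 marks unlabelled samples

mdl = LabelSpreading(max_iter=100)  # generous budget, expected to converge
assert_no_warnings(mdl.fit, X, y)   # fails if fit() emits any warning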

@jnothman
Member

jnothman commented Jul 20, 2017 via email

@musically-ut
Contributor Author

> sounds good

So was that a "yes" to both:

  • Adding a seed to the doctests
  • Adding assert_no_warnings to all calls to mdl.fit in the code

?

@jnothman
Member

jnothman commented Jul 20, 2017 via email

@musically-ut
Contributor Author

I've set np.random.seed for the tests. I am not sure how much state is kept when switching from one set of tests to another, and whether setting the seed globally may reduce the variance of the tests that run downstream. If that is an issue, I can create a test-specific RandomState for each doctest.

Also, I've added assert_no_warnings to the tests which depend on convergence.

@jnothman
Member

jnothman commented Jul 23, 2017

Don't set the seed globally. It is not safe for multi-threaded testing, if nothing else. But usually we'd just use an integer for random_state. Just use rng = np.random.RandomState(42) or something instead.
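
A small illustration of the suggestion (not code from the PR): use a local RandomState rather than mutating NumPy's global generator inside the tests.

import numpy as np
from sklearn.semi_supervised import LabelPropagation

# Avoid: np.random.seed(42) -- it mutates global state shared with other tests.
rng = np.random.RandomState(42)     # local, reproducible generator

X = rng.rand(20, 2)
y = np.array([0, 1] + [-1] * 18)    # -1 marks unlabelled samples
LabelPropagation(max_iter=1000).fit(X, y)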

@jnothman jnothman left a comment
Member

Now that you've put all the assert_no_warnings in there, I think they just add confusion. Sorry. If you want to test convergence, test it with n_iter_. If you want to test that no warnings are issued in convergence (and you should), add this assertion to test_convergence_warning.
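
A hedged sketch of the shape being asked for (the data and exact assertions are illustrative, not the merged test): convergence is checked through n_iter_, and the "no warning when converged" assertion lives inside test_convergence_warning.

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.semi_supervised import LabelPropagation
from sklearn.utils.testing import assert_warns, assert_no_warnings

X = np.array([[1., 0.], [0., 1.], [1., 2.5]])
y = np.array([0, 1, -1])    # -1 marks the unlabelled sample

def test_convergence_warning():
    # too small a budget: the warning fires and the budget is fully used
    mdl = LabelPropagation(kernel='rbf', max_iter=1)
    assert_warns(ConvergenceWarning, mdl.fit, X, y)
    assert mdl.n_iter_ == mdl.max_iter

    # a generous budget: no warning at all
    mdl = LabelPropagation(kernel='rbf', max_iter=500)
    assert_no_warnings(mdl.fit, X, y)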

@jnothman
Member

Thanks

@musically-ut
Contributor Author

Ready for review (not sure whether pushing commits to the branch sends out notifications on GitHub).

@jnothman jnothman left a comment
Member

Yes, we see commits as they come, but it is often helpful to confirm what they mean with a comment.

Sorry to add more work. You can make the change of defaults a separate PR to make sure we get it into 0.19.

Also, please add your name to the list of file authors.

@@ -376,7 +383,7 @@ class LabelPropagation(BaseLabelPropagation):
     _variant = 'propagation'
 
     def __init__(self, kernel='rbf', gamma=20, n_neighbors=7,
-                 alpha=None, max_iter=30, tol=1e-3, n_jobs=1):
+                 alpha=None, max_iter=1000, tol=1e-3, n_jobs=1):
Member

We can only make this change if we include it in the 0.19 final release. Just to note...

Contributor Author

I am not sure I follow, so I'll rephrase what I understood:

If we create a separate PR just with this change, it can get included in the 0.19 final release.

Did I understand the note correctly? If so, I'll create one presently.

Member

Yes :)

@@ -287,6 +288,12 @@ def fit(self, X, y):
                 alpha, self.label_distributions_) + y_static
             remaining_iter -= 1
 
+        if remaining_iter <= 1:
Member

I think this is right given the current implementation, but the current implementation is buggy in that it seems you can never obtain n_iter_==max_iter.

Please fix the convergence condition and inline the code for _not_converged. I would then write something like:

while self.n_iter_ < self.max_iter:
    if ... < tol:
        break
    # do stuff
    self.n_iter_ += 1
else:
    # warn

Contributor Author

Done.

musically-ut added a commit to musically-ut/scikit-learn that referenced this pull request Jul 24, 2017
LabelPropagation converges much slower than LabelSpreading. The default
of max_iter=30 works well for LabelSpreading but not for
LabelPropagation.

This was extracted from scikit-learn#5893.
ogrisel pushed a commit that referenced this pull request Jul 27, 2017
ogrisel pushed a commit that referenced this pull request Jul 27, 2017
@musically-ut musically-ut force-pushed the feat-warn-label-prop-convergence branch from dd9e8fb to 54a5f26 Compare July 28, 2017 07:22
@jnothman jnothman left a comment
Member

LGTM

-                and remaining_iter > 1):
 
+        self.n_iter_ = 0
+        while self.n_iter_ < self.max_iter:
Member

I suppose we could now do this as a for loop...

Contributor Author

You mean setting self.n_iter_ using the loop variable?
I personally feel a bit uncomfortable using the loop variable outside the loop.

(* angsty feeling *)

Contributor Author

So ... shall I make this change to expedite the merge?

Member

I think he meant

for n_iter in range(self.max_iter):

Contributor Author

With the for loop, there are two alternatives:

Alternative 1

Setting self.n_iter_ after the for loop. It would look a bit ugly because we'll need to figure out whether, in the last iteration of the loop, the tolerance condition was met (then self.n_iter_ = n_iter) or the else: branch was reached (in which case self.n_iter_ = n_iter + 1):

self.n_iter_ = 0
for n_iter in range(self.max_iter):
    if converged:
        break
    # ...
else:
    # warn
    self.n_iter_ = 1 # Count the last iteration.

self.n_iter_ += n_iter

Alternative 2

Leaving self.n_iter_ += 1 inside the for loop. In this case the loop variable (i.e. n_iter) is not used anywhere.

self.n_iter_ = 0
for n_iter in range(self.max_iter):
    if converged:
        break
    # ...
    self.n_iter_ += 1
else:
    # warn

Was either of these versions what you (both of you) had in mind? Personally, I like the while loop best, and then Alternative 2. :)

Contributor Author

@amueller In the new proposal, self.n_iter_ can never be zero, and zero is a valid value technically. The handling of these corner cases is what makes the for loop slightly ugly.

I'm sorry, I am away from my laptop and am a bit constrained in my explanations.

Member

This comment was a cosmetic one that is not worth withholding merge for.

Contributor Author

So ... is that a "go" for the while loop? :-)

Member

Whatever you think is most readable. I'd prefer

for self.n_iter_ in ...:
   ...
else:
   ...

I think, but it matters very little

Contributor Author

First, I didn't know that we could use self.n_iter_ in the for loop; today I learned. :)

Second, this is the best implementation I can think of which has the same behavior as the original while loop:

for self.n_iter_ in range(self.max_iter):
    # ...
else:
    # ...
    self.n_iter_ += 1

The self.n_iter_ += 1 in the else: clause is necessary to ensure that self.n_iter_ == self.max_iter is possible and that it correctly happens if the loop doesn't break out due to convergence. This was the issue we were out to fix originally. :)

With this implementation, I am quite happy with the for loop as well. I'll make this change. 👍
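
For readers unfamiliar with Python's for/else, a small standalone sketch (a made-up fixed-point iteration, not the estimator's code) of the pattern settled on above: the else branch runs only when the loop is never broken out of, i.e. exactly when the iteration budget is exhausted.

import warnings
from sklearn.exceptions import ConvergenceWarning

def iterate(x, max_iter=30, tol=1e-3):
    # toy fixed-point update standing in for the label-propagation step
    for n_iter_ in range(max_iter):
        x_new = 0.5 * (x + 2.0 / x)   # converges towards sqrt(2)
        if abs(x_new - x) < tol:
            break                     # converged: the else branch is skipped
        x = x_new
    else:
        n_iter_ += 1                  # the full budget was spent
        warnings.warn("max_iter=%d reached without convergence" % max_iter,
                      ConvergenceWarning)
    return x, n_iter_

x, n_iter_ = iterate(3.0, max_iter=3)  # warns; n_iter_ == 3 == max_iter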

@jnothman jnothman changed the title [MRG] add Convergence warning in LabelPropagation [MRG+1] add Convergence warning in LabelPropagation Jul 29, 2017
@amueller
Member

amueller commented Aug 1, 2017

Do we want this for 0.19?

@amueller
Member

amueller commented Aug 1, 2017

LGTM, no strong opinion on the for loop, though it would be slightly nicer.

@jnothman
Member

jnothman commented Aug 1, 2017

I'm happy for this to be merged. I don't consider it essential for 0.19, but we are newly encouraging people to use these estimators, so why not.

Also, add tests to verify that n_iter_ == max_iter if the warning is raised.
jnothman pushed a commit to jnothman/scikit-learn that referenced this pull request Aug 6, 2017
@musically-ut
Contributor Author

Bump!

@jnothman jnothman merged commit 6d4ae1b into scikit-learn:master Aug 6, 2017
jnothman pushed a commit to jnothman/scikit-learn that referenced this pull request Aug 6, 2017
dmohns pushed a commit to dmohns/scikit-learn that referenced this pull request Aug 7, 2017
dmohns pushed a commit to dmohns/scikit-learn that referenced this pull request Aug 7, 2017
dmohns pushed a commit to dmohns/scikit-learn that referenced this pull request Aug 7, 2017
paulha pushed a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
paulha pushed a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
AishwaryaRK pushed a commit to AishwaryaRK/scikit-learn that referenced this pull request Aug 29, 2017
AishwaryaRK pushed a commit to AishwaryaRK/scikit-learn that referenced this pull request Aug 29, 2017
maskani-moh pushed a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
maskani-moh pushed a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017
jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017