[MRG+1] add Convergence warning in LabelPropagation #5893

Merged
43 changes: 22 additions & 21 deletions sklearn/semi_supervised/label_propagation.py
@@ -34,8 +34,8 @@
>>> from sklearn.semi_supervised import LabelPropagation
>>> label_prop_model = LabelPropagation()
>>> iris = datasets.load_iris()
>>> random_unlabeled_points = np.where(np.random.randint(0, 2,
... size=len(iris.target)))
>>> rng = np.random.RandomState(42)
>>> random_unlabeled_points = rng.rand(len(iris.target)) < 0.3
>>> labels = np.copy(iris.target)
>>> labels[random_unlabeled_points] = -1
>>> label_prop_model.fit(iris.data, labels)
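
For context: the updated doctest hides roughly 30% of the iris labels with a boolean mask drawn from a seeded RandomState instead of the previous unseeded np.random.randint call, so the example is reproducible. Below is a minimal, self-contained sketch of the same masking idea; the seed 42 and the 0.3 fraction simply mirror the doctest above and are otherwise arbitrary.

import numpy as np
from sklearn import datasets
from sklearn.semi_supervised import LabelPropagation

iris = datasets.load_iris()
rng = np.random.RandomState(42)

# Boolean mask: True marks a sample whose label we hide (about 30% of them).
unlabeled_mask = rng.rand(len(iris.target)) < 0.3

labels = np.copy(iris.target)
labels[unlabeled_mask] = -1  # -1 is the "unlabeled" marker the estimator expects

model = LabelPropagation().fit(iris.data, labels)
print(model.transduction_[:10])  # labels inferred for the first ten samples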
@@ -53,6 +53,7 @@
"""

# Authors: Clay Woolam <[email protected]>
# Utkarsh Upadhyay <[email protected]>
# License: BSD
from abc import ABCMeta, abstractmethod

@@ -67,13 +68,7 @@
from ..utils.extmath import safe_sparse_dot
from ..utils.multiclass import check_classification_targets
from ..utils.validation import check_X_y, check_is_fitted, check_array


# Helper functions

def _not_converged(y_truth, y_prediction, tol=1e-3):
"""basic convergence check"""
return np.abs(y_truth - y_prediction).sum() > tol
from ..exceptions import ConvergenceWarning


class BaseLabelPropagation(six.with_metaclass(ABCMeta, BaseEstimator,
@@ -97,7 +92,7 @@ class BaseLabelPropagation(six.with_metaclass(ABCMeta, BaseEstimator,
alpha : float
Clamping factor

max_iter : float
max_iter : integer
Change maximum number of iterations allowed

tol : float
@@ -264,12 +259,14 @@ def fit(self, X, y):

l_previous = np.zeros((self.X_.shape[0], n_classes))

remaining_iter = self.max_iter
unlabeled = unlabeled[:, np.newaxis]
if sparse.isspmatrix(graph_matrix):
graph_matrix = graph_matrix.tocsr()
while (_not_converged(self.label_distributions_, l_previous, self.tol)
and remaining_iter > 1):

for self.n_iter_ in range(self.max_iter):
if np.abs(self.label_distributions_ - l_previous).sum() < self.tol:
break

l_previous = self.label_distributions_
self.label_distributions_ = safe_sparse_dot(
graph_matrix, self.label_distributions_)
@@ -285,7 +282,12 @@
# clamp
self.label_distributions_ = np.multiply(
alpha, self.label_distributions_) + y_static
remaining_iter -= 1
else:
warnings.warn(
'max_iter=%d was reached without convergence.' % self.max_iter,
category=ConvergenceWarning
)
self.n_iter_ += 1

normalizer = np.sum(self.label_distributions_, axis=1)[:, np.newaxis]
self.label_distributions_ /= normalizer
@@ -294,7 +296,6 @@
transduction = self.classes_[np.argmax(self.label_distributions_,
axis=1)]
self.transduction_ = transduction.ravel()
self.n_iter_ = self.max_iter - remaining_iter
return self
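
Note that the rewritten loop relies on Python's for/else: the else clause runs only when the loop completes without hitting break, i.e. when max_iter is exhausted before the tolerance check succeeds, which is exactly when a ConvergenceWarning should be raised. A stripped-down sketch of the same pattern follows; the function name and the toy update rule are made up for illustration and are not the estimator's actual internals.

import warnings
import numpy as np
from sklearn.exceptions import ConvergenceWarning

def iterate_until_converged(update, estimate, max_iter=30, tol=1e-3):
    """Apply `update` repeatedly until the change drops below `tol`."""
    for n_iter in range(max_iter):
        previous = estimate
        estimate = update(estimate)
        if np.abs(estimate - previous).sum() < tol:
            break  # converged: the else clause below is skipped
    else:
        # Reached only when the loop ran out of iterations without a break.
        warnings.warn('max_iter=%d was reached without convergence.' % max_iter,
                      category=ConvergenceWarning)
    return estimate, n_iter

# Halving converges to zero quickly, so this call raises no warning.
result, n_iter = iterate_until_converged(lambda x: 0.5 * x, np.array([1.0]))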


@@ -324,7 +325,7 @@ class LabelPropagation(BaseLabelPropagation):
This parameter will be removed in 0.21.
'alpha' is fixed to zero in 'LabelPropagation'.

max_iter : float
max_iter : integer
Change maximum number of iterations allowed

tol : float
Expand Down Expand Up @@ -358,8 +359,8 @@ class LabelPropagation(BaseLabelPropagation):
>>> from sklearn.semi_supervised import LabelPropagation
>>> label_prop_model = LabelPropagation()
>>> iris = datasets.load_iris()
>>> random_unlabeled_points = np.where(np.random.randint(0, 2,
... size=len(iris.target)))
>>> rng = np.random.RandomState(42)
>>> random_unlabeled_points = rng.rand(len(iris.target)) < 0.3
>>> labels = np.copy(iris.target)
>>> labels[random_unlabeled_points] = -1
>>> label_prop_model.fit(iris.data, labels)
@@ -441,7 +442,7 @@ class LabelSpreading(BaseLabelPropagation):
alpha=0 means keeping the initial label information; alpha=1 means
replacing all initial information.

max_iter : float
max_iter : integer
maximum number of iterations allowed

tol : float
@@ -475,8 +476,8 @@ class LabelSpreading(BaseLabelPropagation):
>>> from sklearn.semi_supervised import LabelSpreading
>>> label_prop_model = LabelSpreading()
>>> iris = datasets.load_iris()
>>> random_unlabeled_points = np.where(np.random.randint(0, 2,
... size=len(iris.target)))
>>> rng = np.random.RandomState(42)
>>> random_unlabeled_points = rng.rand(len(iris.target)) < 0.3
>>> labels = np.copy(iris.target)
>>> labels[random_unlabeled_points] = -1
>>> label_prop_model.fit(iris.data, labels)
25 changes: 23 additions & 2 deletions sklearn/semi_supervised/tests/test_label_propagation.py
@@ -9,6 +9,7 @@
from sklearn.semi_supervised import label_propagation
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
from numpy.testing import assert_array_almost_equal
from numpy.testing import assert_array_equal

@@ -70,7 +71,7 @@ def test_alpha_deprecation():
y[::3] = -1

lp_default = label_propagation.LabelPropagation(kernel='rbf', gamma=0.1)
lp_default_y = assert_no_warnings(lp_default.fit, X, y).transduction_
lp_default_y = lp_default.fit(X, y).transduction_

lp_0 = label_propagation.LabelPropagation(alpha=0, kernel='rbf', gamma=0.1)
lp_0_y = assert_warns(DeprecationWarning, lp_0.fit, X, y).transduction_
@@ -108,7 +109,8 @@ def test_label_propagation_closed_form():
labelled_idx = (Y[:, (-1,)] == 0).nonzero()[0]

clf = label_propagation.LabelPropagation(max_iter=10000,
gamma=0.1).fit(X, y)
gamma=0.1)
clf.fit(X, y)
# adopting notation from Zhu et al 2002
T_bar = clf._build_graph()
Tuu = T_bar[np.meshgrid(unlabelled_idx, unlabelled_idx, indexing='ij')]
@@ -145,3 +147,22 @@ def test_convergence_speed():
# this should converge quickly:
assert mdl.n_iter_ < 10
assert_array_equal(mdl.predict(X), [0, 1, 1])


def test_convergence_warning():
# This is a non-regression test for #5774
X = np.array([[1., 0.], [0., 1.], [1., 2.5]])
y = np.array([0, 1, -1])
mdl = label_propagation.LabelSpreading(kernel='rbf', max_iter=1)
assert_warns(ConvergenceWarning, mdl.fit, X, y)
assert_equal(mdl.n_iter_, mdl.max_iter)

mdl = label_propagation.LabelPropagation(kernel='rbf', max_iter=1)
assert_warns(ConvergenceWarning, mdl.fit, X, y)
assert_equal(mdl.n_iter_, mdl.max_iter)

mdl = label_propagation.LabelSpreading(kernel='rbf', max_iter=500)
assert_no_warnings(mdl.fit, X, y)

mdl = label_propagation.LabelPropagation(kernel='rbf', max_iter=500)
assert_no_warnings(mdl.fit, X, y)
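
As a usage note: the new warning behaves like any other ConvergenceWarning, so callers who deliberately cap max_iter can capture or silence it with the standard warnings machinery. A small sketch under that assumption; the toy dataset simply mirrors the test above.

import warnings
import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.semi_supervised import LabelSpreading

X = np.array([[1., 0.], [0., 1.], [1., 2.5]])
y = np.array([0, 1, -1])

# Record the warning instead of letting it print to stderr.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    LabelSpreading(kernel='rbf', max_iter=1).fit(X, y)
assert any(issubclass(w.category, ConvergenceWarning) for w in caught)

# Or silence it entirely when the iteration cap is intentional.
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=ConvergenceWarning)
    LabelSpreading(kernel='rbf', max_iter=1).fit(X, y)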