diff --git a/sklearn/semi_supervised/label_propagation.py b/sklearn/semi_supervised/label_propagation.py index c690ac1f151f4..10eebba86f04e 100644 --- a/sklearn/semi_supervised/label_propagation.py +++ b/sklearn/semi_supervised/label_propagation.py @@ -34,8 +34,8 @@ >>> from sklearn.semi_supervised import LabelPropagation >>> label_prop_model = LabelPropagation() >>> iris = datasets.load_iris() ->>> random_unlabeled_points = np.where(np.random.randint(0, 2, -... size=len(iris.target))) +>>> rng = np.random.RandomState(42) +>>> random_unlabeled_points = rng.rand(len(iris.target)) < 0.3 >>> labels = np.copy(iris.target) >>> labels[random_unlabeled_points] = -1 >>> label_prop_model.fit(iris.data, labels) @@ -53,6 +53,7 @@ """ # Authors: Clay Woolam +# Utkarsh Upadhyay # License: BSD from abc import ABCMeta, abstractmethod @@ -67,13 +68,7 @@ from ..utils.extmath import safe_sparse_dot from ..utils.multiclass import check_classification_targets from ..utils.validation import check_X_y, check_is_fitted, check_array - - -# Helper functions - -def _not_converged(y_truth, y_prediction, tol=1e-3): - """basic convergence check""" - return np.abs(y_truth - y_prediction).sum() > tol +from ..exceptions import ConvergenceWarning class BaseLabelPropagation(six.with_metaclass(ABCMeta, BaseEstimator, @@ -97,7 +92,7 @@ class BaseLabelPropagation(six.with_metaclass(ABCMeta, BaseEstimator, alpha : float Clamping factor - max_iter : float + max_iter : integer Change maximum number of iterations allowed tol : float @@ -264,12 +259,14 @@ def fit(self, X, y): l_previous = np.zeros((self.X_.shape[0], n_classes)) - remaining_iter = self.max_iter unlabeled = unlabeled[:, np.newaxis] if sparse.isspmatrix(graph_matrix): graph_matrix = graph_matrix.tocsr() - while (_not_converged(self.label_distributions_, l_previous, self.tol) - and remaining_iter > 1): + + for self.n_iter_ in range(self.max_iter): + if np.abs(self.label_distributions_ - l_previous).sum() < self.tol: + break + l_previous = self.label_distributions_ self.label_distributions_ = safe_sparse_dot( graph_matrix, self.label_distributions_) @@ -285,7 +282,12 @@ def fit(self, X, y): # clamp self.label_distributions_ = np.multiply( alpha, self.label_distributions_) + y_static - remaining_iter -= 1 + else: + warnings.warn( + 'max_iter=%d was reached without convergence.' % self.max_iter, + category=ConvergenceWarning + ) + self.n_iter_ += 1 normalizer = np.sum(self.label_distributions_, axis=1)[:, np.newaxis] self.label_distributions_ /= normalizer @@ -294,7 +296,6 @@ def fit(self, X, y): transduction = self.classes_[np.argmax(self.label_distributions_, axis=1)] self.transduction_ = transduction.ravel() - self.n_iter_ = self.max_iter - remaining_iter return self @@ -324,7 +325,7 @@ class LabelPropagation(BaseLabelPropagation): This parameter will be removed in 0.21. 'alpha' is fixed to zero in 'LabelPropagation'. - max_iter : float + max_iter : integer Change maximum number of iterations allowed tol : float @@ -358,8 +359,8 @@ class LabelPropagation(BaseLabelPropagation): >>> from sklearn.semi_supervised import LabelPropagation >>> label_prop_model = LabelPropagation() >>> iris = datasets.load_iris() - >>> random_unlabeled_points = np.where(np.random.randint(0, 2, - ... size=len(iris.target))) + >>> rng = np.random.RandomState(42) + >>> random_unlabeled_points = rng.rand(len(iris.target)) < 0.3 >>> labels = np.copy(iris.target) >>> labels[random_unlabeled_points] = -1 >>> label_prop_model.fit(iris.data, labels) @@ -441,7 +442,7 @@ class LabelSpreading(BaseLabelPropagation): alpha=0 means keeping the initial label information; alpha=1 means replacing all initial information. - max_iter : float + max_iter : integer maximum number of iterations allowed tol : float @@ -475,8 +476,8 @@ class LabelSpreading(BaseLabelPropagation): >>> from sklearn.semi_supervised import LabelSpreading >>> label_prop_model = LabelSpreading() >>> iris = datasets.load_iris() - >>> random_unlabeled_points = np.where(np.random.randint(0, 2, - ... size=len(iris.target))) + >>> rng = np.random.RandomState(42) + >>> random_unlabeled_points = rng.rand(len(iris.target)) < 0.3 >>> labels = np.copy(iris.target) >>> labels[random_unlabeled_points] = -1 >>> label_prop_model.fit(iris.data, labels) diff --git a/sklearn/semi_supervised/tests/test_label_propagation.py b/sklearn/semi_supervised/tests/test_label_propagation.py index 3d5bd21a89110..8cd0cce41d7e9 100644 --- a/sklearn/semi_supervised/tests/test_label_propagation.py +++ b/sklearn/semi_supervised/tests/test_label_propagation.py @@ -9,6 +9,7 @@ from sklearn.semi_supervised import label_propagation from sklearn.metrics.pairwise import rbf_kernel from sklearn.datasets import make_classification +from sklearn.exceptions import ConvergenceWarning from numpy.testing import assert_array_almost_equal from numpy.testing import assert_array_equal @@ -70,7 +71,7 @@ def test_alpha_deprecation(): y[::3] = -1 lp_default = label_propagation.LabelPropagation(kernel='rbf', gamma=0.1) - lp_default_y = assert_no_warnings(lp_default.fit, X, y).transduction_ + lp_default_y = lp_default.fit(X, y).transduction_ lp_0 = label_propagation.LabelPropagation(alpha=0, kernel='rbf', gamma=0.1) lp_0_y = assert_warns(DeprecationWarning, lp_0.fit, X, y).transduction_ @@ -108,7 +109,8 @@ def test_label_propagation_closed_form(): labelled_idx = (Y[:, (-1,)] == 0).nonzero()[0] clf = label_propagation.LabelPropagation(max_iter=10000, - gamma=0.1).fit(X, y) + gamma=0.1) + clf.fit(X, y) # adopting notation from Zhu et al 2002 T_bar = clf._build_graph() Tuu = T_bar[np.meshgrid(unlabelled_idx, unlabelled_idx, indexing='ij')] @@ -145,3 +147,22 @@ def test_convergence_speed(): # this should converge quickly: assert mdl.n_iter_ < 10 assert_array_equal(mdl.predict(X), [0, 1, 1]) + + +def test_convergence_warning(): + # This is a non-regression test for #5774 + X = np.array([[1., 0.], [0., 1.], [1., 2.5]]) + y = np.array([0, 1, -1]) + mdl = label_propagation.LabelSpreading(kernel='rbf', max_iter=1) + assert_warns(ConvergenceWarning, mdl.fit, X, y) + assert_equal(mdl.n_iter_, mdl.max_iter) + + mdl = label_propagation.LabelPropagation(kernel='rbf', max_iter=1) + assert_warns(ConvergenceWarning, mdl.fit, X, y) + assert_equal(mdl.n_iter_, mdl.max_iter) + + mdl = label_propagation.LabelSpreading(kernel='rbf', max_iter=500) + assert_no_warnings(mdl.fit, X, y) + + mdl = label_propagation.LabelPropagation(kernel='rbf', max_iter=500) + assert_no_warnings(mdl.fit, X, y)