From 7dcb38007dd399f657d346e967fbe5dec20e9c7d Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 12 Feb 2019 15:28:15 +0100 Subject: [PATCH 1/2] move _e_step after set_params & add test --- sklearn/mixture/base.py | 10 +++++----- sklearn/mixture/tests/test_gaussian_mixture.py | 9 +++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/sklearn/mixture/base.py b/sklearn/mixture/base.py index bd34333c0630b..8920bef181226 100644 --- a/sklearn/mixture/base.py +++ b/sklearn/mixture/base.py @@ -257,11 +257,6 @@ def fit_predict(self, X, y=None): best_params = self._get_parameters() best_n_iter = n_iter - # Always do a final e-step to guarantee that the labels returned by - # fit_predict(X) are always consistent with fit(X).predict(X) - # for any value of max_iter and tol (and any random_state). - _, log_resp = self._e_step(X) - if not self.converged_: warnings.warn('Initialization %d did not converge. ' 'Try different init parameters, ' @@ -273,6 +268,11 @@ def fit_predict(self, X, y=None): self.n_iter_ = best_n_iter self.lower_bound_ = max_lower_bound + # Always do a final e-step to guarantee that the labels returned by + # fit_predict(X) are always consistent with fit(X).predict(X) + # for any value of max_iter and tol (and any random_state). 
+ _, log_resp = self._e_step(X) + return log_resp.argmax(axis=1) def _e_step(self, X): diff --git a/sklearn/mixture/tests/test_gaussian_mixture.py b/sklearn/mixture/tests/test_gaussian_mixture.py index 4d549ccd7b9d1..3559ecf9ba0c4 100644 --- a/sklearn/mixture/tests/test_gaussian_mixture.py +++ b/sklearn/mixture/tests/test_gaussian_mixture.py @@ -598,6 +598,15 @@ def test_gaussian_mixture_fit_predict(seed, max_iter, tol): assert_greater(adjusted_rand_score(Y, Y_pred2), .95) +def test_gaussian_mixture_fit_predict_n_init(): + # Check that fit_predict is equivalent to fit.predict, when n_init > 1 + X = np.random.RandomState(0).randn(1000, 5) + gm = GaussianMixture(n_components=5, n_init=5, random_state=0) + y_pred1 = gm.fit_predict(X) + y_pred2 = gm.predict(X) + assert_array_equal(y_pred1, y_pred2) + + def test_gaussian_mixture_fit(): # recover the ground truth rng = np.random.RandomState(0) From 032f39a78983e210988f17339e1552b37895c295 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 14 Feb 2019 10:13:32 +0100 Subject: [PATCH 2/2] add test for bayesian & what's new --- doc/whats_new/v0.21.rst | 9 +++++++++ sklearn/mixture/tests/test_bayesian_mixture.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index 7355f75b83d4e..af0ee6df24974 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -291,6 +291,15 @@ Support for Python 3.4 and below has been officially dropped. affects all ensemble methods using decision trees. :issue:`12344` by :user:`Adrin Jalali <adrinjalali>`. +:mod:`sklearn.mixture` +...................... + +- |Fix| Fixed a bug in :class:`mixture.BaseMixture` and therefore on estimators + based on it, i.e. :class:`mixture.GaussianMixture` and + :class:`mixture.BayesianGaussianMixture`, where ``fit_predict`` and + ``fit.predict`` were not equivalent. :issue:`13142` by + :user:`Jérémie du Boisberranger <jeremiedbb>`. + Multiple modules ................ 
diff --git a/sklearn/mixture/tests/test_bayesian_mixture.py b/sklearn/mixture/tests/test_bayesian_mixture.py index c3503a632238e..58df89aabe47f 100644 --- a/sklearn/mixture/tests/test_bayesian_mixture.py +++ b/sklearn/mixture/tests/test_bayesian_mixture.py @@ -451,6 +451,15 @@ def test_bayesian_mixture_fit_predict(seed, max_iter, tol): assert_array_equal(Y_pred1, Y_pred2) +def test_bayesian_mixture_fit_predict_n_init(): + # Check that fit_predict is equivalent to fit.predict, when n_init > 1 + X = np.random.RandomState(0).randn(1000, 5) + gm = BayesianGaussianMixture(n_components=5, n_init=10, random_state=0) + y_pred1 = gm.fit_predict(X) + y_pred2 = gm.predict(X) + assert_array_equal(y_pred1, y_pred2) + + def test_bayesian_mixture_predict_predict_proba(): # this is the same test as test_gaussian_mixture_predict_predict_proba() rng = np.random.RandomState(0)