-
-
Notifications
You must be signed in to change notification settings - Fork 26.5k
[MRG + 1] #10336 adding fit_predict to mixture models #11281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
sklearn/mixture/base.py
Outdated
| times until the change of likelihood or lower bound is less than | ||
| `tol`, otherwise, a `ConvergenceWarning` is raised. | ||
| `tol`, otherwise, a `ConvergenceWarning` is raised. After fitting, it | ||
| predicts the most probable label for the input data points. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this docstring seems not correct.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right. I'm changing it back to original.
sklearn/mixture/base.py
Outdated
| """Estimate model parameters with the EM algorithm. | ||
| The method fit the model `n_init` times and set the parameters with | ||
| The method first fit the model `n_init` times and set the parameters with |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should probably be "fits", right?
| assert_array_equal(Y_pred1, Y_pred2) | ||
|
|
||
|
|
||
| def test_bayesian_mixture_predict_predict_proba(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this was copied from the other test, can you maybe say so?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't get how this relates to the issue... What am I missing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry just saw this! This was copied from other test. I will comment on that. We added this test because to test fit_predict, we intended to test two things: 1. it is equivalent to fit().predict(); 2. it's output is correct. There was no testing for correctness of predict() for bgmm, we added one so that we know fit_predict actually yields the correct output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, great!
|
looks good. Is that jet in your avatar? ;) |
jnothman
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise LGTM.
It looks like those tests could be refactored though
| assert_array_equal(Y_pred1, Y_pred2) | ||
|
|
||
|
|
||
| def test_bayesian_mixture_predict_predict_proba(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't get how this relates to the issue... What am I missing?
| X = rand_data.X[covar_type] | ||
| Y = rand_data.Y | ||
| g = GaussianMixture(n_components=rand_data.n_components, | ||
| random_state=rng, weights_init=rand_data.weights, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should be passing random_state=0 rather than passing an object which will be changed with each iteration
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it doesn't matter because the only thing we need is that the GMM is not changed within one iteration. Using different random_state with different COVARIANCE_TYPE actually tests more robustness? I guess?
|
You have flake8 failures |
|
Just fixed some flake8 problems. There is one left at line 456 of file tests/test_bayesian_mixture.py which I do not know how to solve. |
| Y = rand_data.Y | ||
| bgmm = BayesianGaussianMixture(n_components=rand_data.n_components, | ||
| random_state=rng, | ||
| weight_concentration_prior_type=prior_type, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you could either leave this flake8 issue unsolved, or do:
bgmm = BayesianGaussianMixture(
n_components=rand_data.n_components,
random_state=rng,
weight_concentration_prior_type=prior_type,
covariance_type=covar_type)| assert_array_equal(Y_pred1, Y_pred2) | ||
|
|
||
|
|
||
| def test_bayesian_mixture_predict_predict_proba(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, great!
|
Please add an entry to the change log at I'm not sure if it's better listed under API changes or as an enhancement. |
| times until the change of likelihood or lower bound is less than | ||
| `tol`, otherwise, a `ConvergenceWarning` is raised. After fitting, it | ||
| predicts the most probable label for the input data points. | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps this deserves .. versionadded:: 0.20
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I missed an email and just saw this! Will do soon.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am putting it under API changes, because there seems to be no enhancement in terms of efficiency or accuracy.
|
Thanks @haoranShu |
Reference Issues
Fixes #10336
What does this implement/fix? Explain your changes.
Added
fit_predictmethod to all Gaussian mixture models and added tests to Bayesian Gaussian Mixture and Gaussian Mixture.Any other comments?
fitis changed to callfit-predict, which really does the computation. In this way we can use log_resp conveniently for predict.