-
-
Notifications
You must be signed in to change notification settings - Fork 26.4k
FEA Add metadata routing to GraphicalLassoCV #27566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9a187b0
351ecc4
305c2c4
7814dc7
c370992
dad5555
c796100
29faf26
10d3b32
aa64a17
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| graphical_lasso, | ||
| ) | ||
| from sklearn.datasets import make_sparse_spd_matrix | ||
| from sklearn.model_selection import GroupKFold | ||
| from sklearn.utils import check_random_state | ||
| from sklearn.utils._testing import ( | ||
| _convert_container, | ||
|
|
@@ -254,12 +255,71 @@ def test_graphical_lasso_cv_scores(): | |
| X | ||
| ) | ||
|
|
||
| _assert_graphical_lasso_cv_scores( | ||
| cov=cov, | ||
| n_splits=splits, | ||
| n_refinements=n_refinements, | ||
| n_alphas=n_alphas, | ||
| ) | ||
|
|
||
|
|
||
| # TODO(1.5): remove in 1.5 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is already removed?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this is removed in main yet. I think you see this change only because of the rearrangement otherwise this is not part of this PR. |
||
| def test_graphical_lasso_cov_init_deprecation(): | ||
| """Check that we raise a deprecation warning if providing `cov_init` in | ||
| `graphical_lasso`.""" | ||
| rng, dim, n_samples = np.random.RandomState(0), 20, 100 | ||
| prec = make_sparse_spd_matrix(dim, alpha=0.95, random_state=0) | ||
| cov = linalg.inv(prec) | ||
| X = rng.multivariate_normal(np.zeros(dim), cov, size=n_samples) | ||
|
|
||
| emp_cov = empirical_covariance(X) | ||
| with pytest.warns(FutureWarning, match="cov_init parameter is deprecated"): | ||
| graphical_lasso(emp_cov, alpha=0.1, cov_init=emp_cov) | ||
|
|
||
|
|
||
| @pytest.mark.usefixtures("enable_slep006") | ||
| def test_graphical_lasso_cv_scores_with_routing(global_random_seed): | ||
OmarManzoor marked this conversation as resolved.
Show resolved
Hide resolved
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do we need this on top of
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a test which checks when we route the groups parameter the graphical lasso cv still operates correctly and gives the desired results.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, but we don't have the equivalent test for all the other
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I still don't see what this test brings on top of the common test we have
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This checks for the actual results and values to ensure the output is as desired after routing parameters.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think even if this test might be additional or extra maybe it would still be nice to have it since it compares the values? Let me know what you think @adrinjalali @glemaitre . Otherwise we can remove it to finalize this PR. |
||
| """Check that `GraphicalLassoCV` internally dispatches metadata to | ||
| the splitter. | ||
| """ | ||
| splits = 5 | ||
| n_alphas = 5 | ||
| n_refinements = 3 | ||
| true_cov = np.array( | ||
| [ | ||
| [0.8, 0.0, 0.2, 0.0], | ||
| [0.0, 0.4, 0.0, 0.0], | ||
| [0.2, 0.0, 0.3, 0.1], | ||
| [0.0, 0.0, 0.1, 0.7], | ||
| ] | ||
| ) | ||
| rng = np.random.RandomState(global_random_seed) | ||
| X = rng.multivariate_normal(mean=[0, 0, 0, 0], cov=true_cov, size=300) | ||
| n_samples = X.shape[0] | ||
| groups = rng.randint(0, 5, n_samples) | ||
| params = {"groups": groups} | ||
| cv = GroupKFold(n_splits=splits) | ||
| cv.set_split_request(groups=True) | ||
|
|
||
| cov = GraphicalLassoCV(cv=cv, alphas=n_alphas, n_refinements=n_refinements).fit( | ||
| X, **params | ||
| ) | ||
|
|
||
| _assert_graphical_lasso_cv_scores( | ||
| cov=cov, | ||
| n_splits=splits, | ||
| n_refinements=n_refinements, | ||
| n_alphas=n_alphas, | ||
| ) | ||
|
|
||
|
|
||
| def _assert_graphical_lasso_cv_scores(cov, n_splits, n_refinements, n_alphas): | ||
| cv_results = cov.cv_results_ | ||
| # alpha and one for each split | ||
|
|
||
| total_alphas = n_refinements * n_alphas + 1 | ||
| keys = ["alphas"] | ||
| split_keys = [f"split{i}_test_score" for i in range(splits)] | ||
| split_keys = [f"split{i}_test_score" for i in range(n_splits)] | ||
| for key in keys + split_keys: | ||
| assert key in cv_results | ||
| assert len(cv_results[key]) == total_alphas | ||
|
|
@@ -270,17 +330,3 @@ def test_graphical_lasso_cv_scores(): | |
|
|
||
| assert_allclose(cov.cv_results_["mean_test_score"], expected_mean) | ||
| assert_allclose(cov.cv_results_["std_test_score"], expected_std) | ||
|
|
||
|
|
||
| # TODO(1.5): remove in 1.5 | ||
| def test_graphical_lasso_cov_init_deprecation(): | ||
| """Check that we raise a deprecation warning if providing `cov_init` in | ||
| `graphical_lasso`.""" | ||
| rng, dim, n_samples = np.random.RandomState(0), 20, 100 | ||
| prec = make_sparse_spd_matrix(dim, alpha=0.95, random_state=0) | ||
| cov = linalg.inv(prec) | ||
| X = rng.multivariate_normal(np.zeros(dim), cov, size=n_samples) | ||
|
|
||
| emp_cov = empirical_covariance(X) | ||
| with pytest.warns(FutureWarning, match="cov_init parameter is deprecated"): | ||
| graphical_lasso(emp_cov, alpha=0.1, cov_init=emp_cov) | ||
Uh oh!
There was an error while loading. Please reload this page.