-
-
Notifications
You must be signed in to change notification settings - Fork 26k
ENH allow shrunk_covariance to handle multiple matrices at once #25275
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
a1e27d1
update shrunk_covariance
qbarthelemy 3ff8c2e
complete tests
qbarthelemy 09985c1
correct n_features
qbarthelemy 9f2c51f
complete changelog
qbarthelemy 44782e8
change quote
qbarthelemy cd75f5a
allow nd in check array
qbarthelemy 09f8f8a
Apply suggestions from code review
qbarthelemy ab8aacb
correct log and correct expand_dims
qbarthelemy 22beec4
split tests and parametrize
qbarthelemy cf1a82c
use assert_allclose
qbarthelemy de0ba82
update whats new
qbarthelemy 61e1d7e
correct whats new
qbarthelemy c7a3dbd
Merge remote-tracking branch 'origin/main' into pr/qbarthelemy/25275
glemaitre 4d88273
Merge branch 'main' into covariance_shrunk
jeremiedbb File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -109,23 +109,23 @@ def _oas(X, *, assume_centered=False): | |
prefer_skip_nested_validation=True, | ||
) | ||
def shrunk_covariance(emp_cov, shrinkage=0.1): | ||
"""Calculate a covariance matrix shrunk on the diagonal. | ||
"""Calculate covariance matrices shrunk on the diagonal. | ||
|
||
Read more in the :ref:`User Guide <shrunk_covariance>`. | ||
|
||
Parameters | ||
---------- | ||
emp_cov : array-like of shape (n_features, n_features) | ||
Covariance matrix to be shrunk. | ||
emp_cov : array-like of shape (..., n_features, n_features) | ||
Covariance matrices to be shrunk, at least 2D ndarray. | ||
glemaitre marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
shrinkage : float, default=0.1 | ||
Coefficient in the convex combination used for the computation | ||
of the shrunk estimate. Range is [0, 1]. | ||
|
||
Returns | ||
------- | ||
shrunk_cov : ndarray of shape (n_features, n_features) | ||
Shrunk covariance. | ||
shrunk_cov : ndarray of shape (..., n_features, n_features) | ||
Shrunk covariance matrices. | ||
glemaitre marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Notes | ||
----- | ||
|
@@ -135,12 +135,13 @@ def shrunk_covariance(emp_cov, shrinkage=0.1): | |
|
||
where `mu = trace(cov) / n_features`. | ||
""" | ||
emp_cov = check_array(emp_cov) | ||
n_features = emp_cov.shape[0] | ||
emp_cov = check_array(emp_cov, allow_nd=True) | ||
n_features = emp_cov.shape[-1] | ||
|
||
mu = np.trace(emp_cov) / n_features | ||
shrunk_cov = (1.0 - shrinkage) * emp_cov | ||
shrunk_cov.flat[:: n_features + 1] += shrinkage * mu | ||
mu = np.trace(emp_cov, axis1=-2, axis2=-1) / n_features | ||
mu = np.expand_dims(mu, axis=tuple(range(mu.ndim, emp_cov.ndim))) | ||
shrunk_cov += shrinkage * mu * np.eye(n_features) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here, we materialize the |
||
|
||
return shrunk_cov | ||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggestions about docstring seem to restrict usage to 2D and 3D arrays, whereas function is now able to process any nd array.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the use case of n-D shrinkage? Since the covariance matrix is always a 2-D matrix, I don't really see when you will pass a 4-D array, for instance.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Most of NumPy and SciPy functions use
(..., n, n)
to indicate that they can process ndarrays,like numpy.linalg.eig and scipy.linalg.expm for example.
Use cases belong to users. But, for an example with a 4D array, shape
(k, m, n, n)
, one might want to shrunkk
sets, each set containing ofm
covariance matrices. I have tested, and code is ok.For me, description should not restrain actual usage. But, as you wish.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking more of restraining on purpose the usage that we have in scikit-learn.
But we can go this road and see what other reviewers think.