diff --git a/doc/modules/covariance.rst b/doc/modules/covariance.rst index c97676ea62108..50927f9a677f6 100644 --- a/doc/modules/covariance.rst +++ b/doc/modules/covariance.rst @@ -160,8 +160,10 @@ object to the same sample. .. topic:: References: - .. [2] Chen et al., "Shrinkage Algorithms for MMSE Covariance Estimation", - IEEE Trans. on Sign. Proc., Volume 58, Issue 10, October 2010. + .. [2] :arxiv:`"Shrinkage algorithms for MMSE covariance estimation.", + Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O. + IEEE Transactions on Signal Processing, 58(10), 5016-5029, 2010. + <0907.4698>` .. topic:: Examples: diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index 73a7ce233981b..5cdc9f3d212ad 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -44,9 +44,16 @@ def _ledoit_wolf(X, *, assume_centered, block_size): def _oas(X, *, assume_centered=False): - """Estimate covariance with the Oracle Approximating Shrinkage algorithm.""" - # for only one feature, the result is the same whatever the shrinkage + """Estimate covariance with the Oracle Approximating Shrinkage algorithm. + + The formulation is based on [1]_. + [1] "Shrinkage algorithms for MMSE covariance estimation.", + Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O. + IEEE Transactions on Signal Processing, 58(10), 5016-5029, 2010. + https://arxiv.org/pdf/0907.4698.pdf + """ if len(X.shape) == 2 and X.shape[1] == 1: + # for only one feature, the result is the same whatever the shrinkage if not assume_centered: X = X - X.mean() return np.atleast_2d((X**2).mean()), 0.0 @@ -54,14 +61,33 @@ def _oas(X, *, assume_centered=False): n_samples, n_features = X.shape emp_cov = empirical_covariance(X, assume_centered=assume_centered) - mu = np.trace(emp_cov) / n_features - # formula from Chen et al.'s **implementation** + # The shrinkage is defined as: + # shrinkage = min( + # trace(S @ S.T) + trace(S)**2) / ((n + 1) (trace(S @ S.T) - trace(S)**2 / p), 1 + # ) + # where n and p are n_samples and n_features, respectively (cf. Eq. 23 in [1]). + # The factor 2 / p is omitted since it does not impact the value of the estimator + # for large p. + + # Instead of computing trace(S)**2, we can compute the average of the squared + # elements of S that is equal to trace(S)**2 / p**2. + # See the definition of the Frobenius norm: + # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm alpha = np.mean(emp_cov**2) - num = alpha + mu**2 - den = (n_samples + 1.0) * (alpha - (mu**2) / n_features) + mu = np.trace(emp_cov) / n_features + mu_squared = mu**2 + # The factor 1 / p**2 will cancel out since it is in both the numerator and + # denominator + num = alpha + mu_squared + den = (n_samples + 1) * (alpha - mu_squared / n_features) shrinkage = 1.0 if den == 0 else min(num / den, 1.0) + + # The shrunk covariance is defined as: + # (1 - shrinkage) * S + shrinkage * F (cf. Eq. 4 in [1]) + # where S is the empirical covariance and F is the shrinkage target defined as + # F = trace(S) / n_features * np.identity(n_features) (cf. Eq. 3 in [1]) shrunk_cov = (1.0 - shrinkage) * emp_cov shrunk_cov.flat[:: n_features + 1] += shrinkage * mu @@ -536,7 +562,9 @@ def fit(self, X, y=None): # OAS estimator @validate_params({"X": ["array-like"]}) def oas(X, *, assume_centered=False): - """Estimate covariance with the Oracle Approximating Shrinkage algorithm. + """Estimate covariance with the Oracle Approximating Shrinkage as proposed in [1]_. + + Read more in the :ref:`User Guide `. Parameters ---------- @@ -560,14 +588,25 @@ def oas(X, *, assume_centered=False): Notes ----- - The regularised (shrunk) covariance is: + The regularised covariance is: - (1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features) + (1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features), - where mu = trace(cov) / n_features + where mu = trace(cov) / n_features and shrinkage is given by the OAS formula + (see [1]_). + + The shrinkage formulation implemented here differs from Eq. 23 in [1]_. In + the original article, formula (23) states that 2/p (p being the number of + features) is multiplied by Trace(cov*cov) in both the numerator and + denominator, but this operation is omitted because for a large p, the value + of 2/p is so small that it doesn't affect the value of the estimator. - The formula we used to implement the OAS is slightly modified compared - to the one given in the article. See :class:`OAS` for more details. + References + ---------- + .. [1] :arxiv:`"Shrinkage algorithms for MMSE covariance estimation.", + Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O. + IEEE Transactions on Signal Processing, 58(10), 5016-5029, 2010. + <0907.4698>` """ estimator = OAS( assume_centered=assume_centered, @@ -576,20 +615,10 @@ def oas(X, *, assume_centered=False): class OAS(EmpiricalCovariance): - """Oracle Approximating Shrinkage Estimator. + """Oracle Approximating Shrinkage Estimator as proposed in [1]_. Read more in the :ref:`User Guide `. - OAS is a particular form of shrinkage described in - "Shrinkage Algorithms for MMSE Covariance Estimation" - Chen et al., IEEE Trans. on Sign. Proc., Volume 58, Issue 10, October 2010. - - The formula used here does not correspond to the one given in the - article. In the original article, formula (23) states that 2/p is - multiplied by Trace(cov*cov) in both the numerator and denominator, but - this operation is omitted because for a large p, the value of 2/p is - so small that it doesn't affect the value of the estimator. - Parameters ---------- store_precision : bool, default=True @@ -646,15 +675,23 @@ class OAS(EmpiricalCovariance): ----- The regularised covariance is: - (1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features) + (1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features), - where mu = trace(cov) / n_features - and shrinkage is given by the OAS formula (see References) + where mu = trace(cov) / n_features and shrinkage is given by the OAS formula + (see [1]_). + + The shrinkage formulation implemented here differs from Eq. 23 in [1]_. In + the original article, formula (23) states that 2/p (p being the number of + features) is multiplied by Trace(cov*cov) in both the numerator and + denominator, but this operation is omitted because for a large p, the value + of 2/p is so small that it doesn't affect the value of the estimator. References ---------- - "Shrinkage Algorithms for MMSE Covariance Estimation" - Chen et al., IEEE Trans. on Sign. Proc., Volume 58, Issue 10, October 2010. + .. [1] :arxiv:`"Shrinkage algorithms for MMSE covariance estimation.", + Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O. + IEEE Transactions on Signal Processing, 58(10), 5016-5029, 2010. + <0907.4698>` Examples --------