diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index c658bc6b12452..d56914f874b42 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -560,6 +560,7 @@ From text
    feature_selection.chi2
    feature_selection.f_classif
    feature_selection.f_regression
+   feature_selection.r_regression
    feature_selection.mutual_info_classif
    feature_selection.mutual_info_regression

diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index a566d03ae1bbc..eaf02942cf316 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -103,6 +103,14 @@ Changelog
   input strings would result in negative indices in the transformed data.
   :pr:`19035` by :user:`Liu Yu `.

+:mod:`sklearn.feature_selection`
+................................
+
+- |Feature| :func:`feature_selection.r_regression` computes Pearson's R
+  correlation coefficients between the features and the target.
+  :pr:`17169` by `Dmytro Lituiev `
+  and `Julien Jerphanion `.
+
 :mod:`sklearn.inspection`
 .........................

diff --git a/sklearn/feature_selection/__init__.py b/sklearn/feature_selection/__init__.py
index 86e8a2af39084..ef894b40065de 100644
--- a/sklearn/feature_selection/__init__.py
+++ b/sklearn/feature_selection/__init__.py
@@ -8,6 +8,7 @@
 from ._univariate_selection import f_classif
 from ._univariate_selection import f_oneway
 from ._univariate_selection import f_regression
+from ._univariate_selection import r_regression
 from ._univariate_selection import SelectPercentile
 from ._univariate_selection import SelectKBest
 from ._univariate_selection import SelectFpr
@@ -44,6 +45,7 @@
            'f_classif',
            'f_oneway',
            'f_regression',
+           'r_regression',
            'mutual_info_classif',
            'mutual_info_regression',
            'SelectorMixin']
diff --git a/sklearn/feature_selection/_univariate_selection.py b/sklearn/feature_selection/_univariate_selection.py
index 0656e27d6e30f..7fc69a4b13cf2 100644
--- a/sklearn/feature_selection/_univariate_selection.py
+++ b/sklearn/feature_selection/_univariate_selection.py
@@ -229,60 +229,53 @@ def chi2(X, y):
     return _chisquare(observed, expected)


-@_deprecate_positional_args
-def f_regression(X, y, *, center=True):
-    """Univariate linear regression tests.
+def r_regression(X, y, *, center=True):
+    """Compute Pearson's r for each feature and the target.
+
+    Pearson's r is also known as the Pearson correlation coefficient.
+
+    .. versionadded:: 1.0

     Linear model for testing the individual effect of each of many regressors.
     This is a scoring function to be used in a feature selection procedure,
     not a free standing feature selection procedure.

-    This is done in 2 steps:
-
-    1. The correlation between each regressor and the target is computed,
-       that is, ((X[:, i] - mean(X[:, i])) * (y - mean_y)) / (std(X[:, i]) *
-       std(y)).
-    2. It is converted to an F score then to a p-value.
+    The cross correlation between each regressor and the target is computed
+    as ((X[:, i] - mean(X[:, i])) * (y - mean_y)) / (std(X[:, i]) * std(y)).

     For more on usage see the :ref:`User Guide <univariate_feature_selection>`.

     Parameters
     ----------
-    X : {array-like, sparse matrix} shape = (n_samples, n_features)
-        The set of regressors that will be tested sequentially.
+    X : {array-like, sparse matrix} of shape (n_samples, n_features)
+        The data matrix.

-    y : array of shape(n_samples).
-        The data matrix
+    y : array-like of shape (n_samples,)
+        The target vector.

     center : bool, default=True
-        If true, X and y will be centered.
+        Whether or not to center the data matrix `X` and the target vector `y`.
+        By default, `X` and `y` will be centered.

     Returns
     -------
-    F : array, shape=(n_features,)
-        F values of features.
-
-    pval : array, shape=(n_features,)
-        p-values of F-scores.
+    correlation_coefficient : ndarray of shape (n_features,)
+        Pearson's R correlation coefficients of features.

     See Also
     --------
-    mutual_info_regression : Mutual information for a continuous target.
-    f_classif : ANOVA F-value between label/feature for classification tasks.
-    chi2 : Chi-squared stats of non-negative features for classification tasks.
-    SelectKBest : Select features based on the k highest scores.
-    SelectFpr : Select features based on a false positive rate test.
-    SelectFdr : Select features based on an estimated false discovery rate.
-    SelectFwe : Select features based on family-wise error rate.
-    SelectPercentile : Select features based on percentile of the highest
-        scores.
+    f_regression: Univariate linear regression tests returning F-statistic
+        and p-values.
+    mutual_info_regression: Mutual information for a continuous target.
+    f_classif: ANOVA F-value between label/feature for classification tasks.
+    chi2: Chi-squared stats of non-negative features for classification tasks.
     """
     X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
                      dtype=np.float64)
     n_samples = X.shape[0]

-    # compute centered values
-    # note that E[(x - mean(x))*(y - mean(y))] = E[x*(y - mean(y))], so we
+    # Compute centered values
+    # Note that E[(x - mean(x))*(y - mean(y))] = E[x*(y - mean(y))], so we
     # need not center X
     if center:
         y = y - np.mean(y)
@@ -290,22 +283,86 @@ def f_regression(X, y, *, center=True):
             X_means = X.mean(axis=0).getA1()
         else:
             X_means = X.mean(axis=0)
-        # compute the scaled standard deviations via moments
+        # Compute the scaled standard deviations via moments
         X_norms = np.sqrt(row_norms(X.T, squared=True) -
                           n_samples * X_means ** 2)
     else:
         X_norms = row_norms(X.T)

-    # compute the correlation
-    corr = safe_sparse_dot(y, X)
-    corr /= X_norms
-    corr /= np.linalg.norm(y)
+    correlation_coefficient = safe_sparse_dot(y, X)
+    correlation_coefficient /= X_norms
+    correlation_coefficient /= np.linalg.norm(y)
+    return correlation_coefficient
+
+
+@_deprecate_positional_args
+def f_regression(X, y, *, center=True):
+    """Univariate linear regression tests returning F-statistic and p-values.
+
+    Quick linear model for testing the effect of a single regressor,
+    sequentially for many regressors.
+
+    This is done in 2 steps:
+
+    1. The cross correlation between each regressor and the target is computed,
+       that is, ((X[:, i] - mean(X[:, i])) * (y - mean_y)) / (std(X[:, i]) *
+       std(y)) using the :func:`r_regression` function.
+    2. It is converted to an F score and then to a p-value.
+
+    :func:`f_regression` is derived from :func:`r_regression` and will rank
+    features in the same order if all the features are positively correlated
+    with the target.
+
+    Note however that contrary to :func:`f_regression`, :func:`r_regression`
+    values lie in [-1, 1] and can thus be negative. :func:`f_regression` is
+    therefore recommended as a feature selection criterion to identify
+    potentially predictive features for a downstream classifier, irrespective
+    of the sign of the association with the target variable.
+
+    Furthermore :func:`f_regression` returns p-values while
+    :func:`r_regression` does not.
+
+    Read more in the :ref:`User Guide <univariate_feature_selection>`.
+
+    Parameters
+    ----------
+    X : {array-like, sparse matrix} of shape (n_samples, n_features)
+        The data matrix.
+
+    y : array-like of shape (n_samples,)
+        The target vector.
+
+    center : bool, default=True
+        Whether or not to center the data matrix `X` and the target vector `y`.
+        By default, `X` and `y` will be centered.
+
+    Returns
+    -------
+    f_statistic : ndarray of shape (n_features,)
+        F-statistic for each feature.
+
+    p_values : ndarray of shape (n_features,)
+        P-values associated with the F-statistic.
+
+    See Also
+    --------
+    r_regression: Pearson's R between label/feature for regression tasks.
+    f_classif: ANOVA F-value between label/feature for classification tasks.
+    chi2: Chi-squared stats of non-negative features for classification tasks.
+    SelectKBest: Select features based on the k highest scores.
+    SelectFpr: Select features based on a false positive rate test.
+    SelectFdr: Select features based on an estimated false discovery rate.
+    SelectFwe: Select features based on family-wise error rate.
+    SelectPercentile: Select features based on percentile of the highest
+        scores.
+    """
+    correlation_coefficient = r_regression(X, y, center=center)
+    deg_of_freedom = y.size - (2 if center else 1)

-    # convert to p-value
-    degrees_of_freedom = y.size - (2 if center else 1)
-    F = corr ** 2 / (1 - corr ** 2) * degrees_of_freedom
-    pv = stats.f.sf(F, 1, degrees_of_freedom)
-    return F, pv
+    corr_coef_squared = correlation_coefficient ** 2
+    f_statistic = corr_coef_squared / (1 - corr_coef_squared) * deg_of_freedom
+    p_values = stats.f.sf(f_statistic, 1, deg_of_freedom)
+    return f_statistic, p_values


 ######################################################################
@@ -502,12 +559,12 @@ class SelectKBest(_BaseFilter):

     See Also
     --------
-    f_classif : ANOVA F-value between label/feature for classification tasks.
-    mutual_info_classif : Mutual information for a discrete target.
-    chi2 : Chi-squared stats of non-negative features for classification tasks.
-    f_regression : F-value between label/feature for regression tasks.
-    mutual_info_regression : Mutual information for a continuous target.
-    SelectPercentile : Select features based on percentile of the highest
+    f_classif: ANOVA F-value between label/feature for classification tasks.
+    mutual_info_classif: Mutual information for a discrete target.
+    chi2: Chi-squared stats of non-negative features for classification tasks.
+    f_regression: F-value between label/feature for regression tasks.
+    mutual_info_regression: Mutual information for a continuous target.
+    SelectPercentile: Select features based on percentile of the highest
         scores.
     SelectFpr : Select features based on a false positive rate test.
     SelectFdr : Select features based on an estimated false discovery rate.
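For reference while reviewing, the r-to-F conversion implemented above can be sanity-checked outside the library with a small standalone sketch. The helper below is hypothetical (plain NumPy/SciPy, dense `X`, centered data only); it mirrors the math of `r_regression` followed by `f_regression` but is not the library code and skips the sparse and `center=False` paths.

import numpy as np
from scipy import stats


def pearson_r_to_f(X, y):
    # Pearson's r of each column of X against y, with both centered.
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    Xc = X - X.mean(axis=0)
    yc = y - y.mean()
    r = (Xc * yc[:, None]).sum(axis=0) / (
        np.linalg.norm(Xc, axis=0) * np.linalg.norm(yc))

    # Convert r to an F-statistic with (1, n_samples - 2) degrees of freedom,
    # then to a p-value, mirroring the conversion in f_regression above.
    dof = y.size - 2
    f_statistic = r ** 2 / (1 - r ** 2) * dof
    p_values = stats.f.sf(f_statistic, 1, dof)
    return r, f_statistic, p_values

On a dense toy problem this should agree with the new `r_regression(X, y)` and with `f_regression(X, y)` up to floating-point error.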
diff --git a/sklearn/feature_selection/tests/test_feature_select.py b/sklearn/feature_selection/tests/test_feature_select.py
index 61f709094147e..852c8228b2a76 100644
--- a/sklearn/feature_selection/tests/test_feature_select.py
+++ b/sklearn/feature_selection/tests/test_feature_select.py
@@ -4,11 +4,12 @@
 import itertools
 import warnings
 import numpy as np
+from numpy.testing import assert_allclose
 from scipy import stats, sparse

 import pytest

-from sklearn.utils._testing import assert_almost_equal
+from sklearn.utils._testing import assert_almost_equal, _convert_container
 from sklearn.utils._testing import assert_array_equal
 from sklearn.utils._testing import assert_array_almost_equal
 from sklearn.utils._testing import assert_warns
@@ -18,9 +19,20 @@

 from sklearn.datasets import make_classification, make_regression
 from sklearn.feature_selection import (
-    chi2, f_classif, f_oneway, f_regression, mutual_info_classif,
-    mutual_info_regression, SelectPercentile, SelectKBest, SelectFpr,
-    SelectFdr, SelectFwe, GenericUnivariateSelect)
+    chi2,
+    f_classif,
+    f_oneway,
+    f_regression,
+    GenericUnivariateSelect,
+    mutual_info_classif,
+    mutual_info_regression,
+    r_regression,
+    SelectPercentile,
+    SelectKBest,
+    SelectFpr,
+    SelectFdr,
+    SelectFwe,
+)


 ##############################################################################
@@ -71,6 +83,27 @@ def test_f_classif():
     assert_array_almost_equal(pv_sparse, pv)


+@pytest.mark.parametrize("center", [True, False])
+def test_r_regression(center):
+    X, y = make_regression(n_samples=2000, n_features=20, n_informative=5,
+                           shuffle=False, random_state=0)
+
+    corr_coeffs = r_regression(X, y, center=center)
+    assert ((-1 < corr_coeffs).all())
+    assert ((corr_coeffs < 1).all())
+
+    sparse_X = _convert_container(X, "sparse")
+
+    sparse_corr_coeffs = r_regression(sparse_X, y, center=center)
+    assert_allclose(sparse_corr_coeffs, corr_coeffs)
+
+    # Testing against numpy for reference
+    Z = np.hstack((X, y[:, np.newaxis]))
+    correlation_matrix = np.corrcoef(Z, rowvar=False)
+    np_corr_coeffs = correlation_matrix[:-1, -1]
+    assert_array_almost_equal(np_corr_coeffs, corr_coeffs, decimal=3)
+
+
 def test_f_regression():
     # Test whether the F test yields meaningful results
     # on a simple simulated regression problem
@@ -87,14 +120,14 @@ def test_f_regression():
     # with centering, compare with sparse
     F, pv = f_regression(X, y, center=True)
     F_sparse, pv_sparse = f_regression(sparse.csr_matrix(X), y, center=True)
-    assert_array_almost_equal(F_sparse, F)
-    assert_array_almost_equal(pv_sparse, pv)
+    assert_allclose(F_sparse, F)
+    assert_allclose(pv_sparse, pv)

     # again without centering, compare with sparse
     F, pv = f_regression(X, y, center=False)
     F_sparse, pv_sparse = f_regression(sparse.csr_matrix(X), y, center=False)
-    assert_array_almost_equal(F_sparse, F)
-    assert_array_almost_equal(pv_sparse, pv)
+    assert_allclose(F_sparse, F)
+    assert_allclose(pv_sparse, pv)


 def test_f_regression_input_dtype():
@@ -106,8 +139,8 @@ def test_f_regression_input_dtype():

     F1, pv1 = f_regression(X, y)
     F2, pv2 = f_regression(X, y.astype(float))
-    assert_array_almost_equal(F1, F2, 5)
-    assert_array_almost_equal(pv1, pv2, 5)
+    assert_allclose(F1, F2, rtol=1e-5)
+    assert_allclose(pv1, pv2, rtol=1e-5)


 def test_f_regression_center():
@@ -123,7 +156,7 @@ def test_f_regression_center():

     F1, _ = f_regression(X, Y, center=True)
     F2, _ = f_regression(X, Y, center=False)
-    assert_array_almost_equal(F1 * (n_samples - 1.) / (n_samples - 2.), F2)
+    assert_allclose(F1 * (n_samples - 1.) / (n_samples - 2.), F2)
     assert_almost_equal(F2[0], 0.232558139)  # value from statsmodels OLS


@@ -262,7 +295,7 @@ def test_select_heuristics_classif():
             f_classif, mode=mode, param=0.01).fit(X, y).transform(X)
         assert_array_equal(X_r, X_r2)
         support = univariate_filter.get_support()
-        assert_array_almost_equal(support, gtruth)
+        assert_allclose(support, gtruth)


 ##############################################################################
@@ -272,7 +305,7 @@ def test_select_heuristics_classif():
 def assert_best_scores_kept(score_filter):
     scores = score_filter.scores_
     support = score_filter.get_support()
-    assert_array_almost_equal(np.sort(scores[support]),
-                              np.sort(scores)[-support.sum():])
+    assert_allclose(np.sort(scores[support]),
+                    np.sort(scores)[-support.sum():])
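As a quick usage sketch of the new public function once this lands (scikit-learn >= 1.0); the dataset parameters below are arbitrary and not taken from the PR or its tests.

import numpy as np
from sklearn.datasets import make_regression
from sklearn.feature_selection import SelectKBest, f_regression, r_regression

X, y = make_regression(n_samples=500, n_features=10, n_informative=3,
                       random_state=0)

# Pearson's r per feature: signed values in [-1, 1].
r = r_regression(X, y)

# f_regression squares the correlation, so features are ranked by the
# magnitude of their association with y, irrespective of its sign.
f_statistic, p_values = f_regression(X, y)

# For selection irrespective of sign, f_regression remains the score_func
# to pass to the selectors.
X_top = SelectKBest(score_func=f_regression, k=3).fit_transform(X, y)
print(np.round(r, 3))
print(np.round(f_statistic, 1))
print(X_top.shape)  # (500, 3)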