From df08defb051703f9961f19775aee38869cb44769 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Wed, 23 Oct 2013 13:02:08 -0400 Subject: [PATCH 01/25] ENH: Implemented mutual_info function --- doc/modules/feature_selection.rst | 2 +- examples/feature_selection/plot_rfe_digits.py | 2 +- sklearn/feature_selection/__init__.py | 6 +- sklearn/feature_selection/mutual_info.py | 297 ++++++++++++++++++ .../tests/test_feature_select.py | 67 +++- .../tests/test_mutual_info.py | 172 ++++++++++ .../feature_selection/univariate_selection.py | 29 +- 7 files changed, 560 insertions(+), 15 deletions(-) create mode 100644 sklearn/feature_selection/mutual_info.py create mode 100644 sklearn/feature_selection/tests/test_mutual_info.py diff --git a/doc/modules/feature_selection.rst b/doc/modules/feature_selection.rst index 60e4d0a38f7c8..679db067ebf8a 100644 --- a/doc/modules/feature_selection.rst +++ b/doc/modules/feature_selection.rst @@ -315,4 +315,4 @@ Then, a :class:`sklearn.ensemble.RandomForestClassifier` is trained on the transformed output, i.e. using only relevant features. You can perform similar operations with the other feature selection methods and also classifiers that provide a way to evaluate feature importances of course. -See the :class:`sklearn.pipeline.Pipeline` examples for more details. +See the :class:`sklearn.pipeline.Pipeline` examples for more details. \ No newline at end of file diff --git a/examples/feature_selection/plot_rfe_digits.py b/examples/feature_selection/plot_rfe_digits.py index 626a25afef231..2427944a2f112 100644 --- a/examples/feature_selection/plot_rfe_digits.py +++ b/examples/feature_selection/plot_rfe_digits.py @@ -33,4 +33,4 @@ plt.matshow(ranking, cmap=plt.cm.Blues) plt.colorbar() plt.title("Ranking of pixels with RFE") -plt.show() +plt.show() \ No newline at end of file diff --git a/sklearn/feature_selection/__init__.py b/sklearn/feature_selection/__init__.py index acb03f6f24a9e..e41fa12cd94e8 100644 --- a/sklearn/feature_selection/__init__.py +++ b/sklearn/feature_selection/__init__.py @@ -22,6 +22,9 @@ from .from_model import SelectFromModel +from .mutual_info import mutual_info + + __all__ = ['GenericUnivariateSelect', 'RFE', 'RFECV', @@ -29,10 +32,11 @@ 'SelectFpr', 'SelectFwe', 'SelectKBest', + 'SelectFromModel' 'SelectPercentile', 'VarianceThreshold', 'chi2', 'f_classif', 'f_oneway', 'f_regression', - 'SelectFromModel'] + 'mutual_info'] diff --git a/sklearn/feature_selection/mutual_info.py b/sklearn/feature_selection/mutual_info.py new file mode 100644 index 0000000000000..3638712ef7e74 --- /dev/null +++ b/sklearn/feature_selection/mutual_info.py @@ -0,0 +1,297 @@ +# Author: Nikolay Mayorov +# License: 3-clause BSD +from __future__ import division + +import numpy as np +from scipy.sparse import issparse +from scipy.special import digamma + +from ..metrics.cluster.supervised import mutual_info_score +from ..neighbors import NearestNeighbors +from ..preprocessing import scale +from ..utils import check_random_state +from ..utils.validation import check_X_y + + +def _compute_mi_cc(x, y, n_neighbors): + """Compute mutual information between two continuous variables. + + Parameters + ---------- + x, y : ndarray + Samples from random variables, 1-d arrays of identical shape. + n_neighbors : int + Number of nearest neighbors to search for each point, see [1]_. + + Returns + ------- + mi : float + Estimated mutual information. If it turned out to be negative it is + replace by 0. + + Notes + ----- + True mutual information can't be negative. 
If its estimate by a numerical + method is negative, it means (providing the method is adequate) that the + mutual information is close to 0 and replacing it by 0 is a reasonable + strategy. + + References + ---------- + .. [1] A. Kraskov, H. Stogbauer and P. Grassberger, "Estimating mutual + information". Phys. Rev. E 69, 2004. + """ + n_samples = x.size + + x = x.reshape((-1, 1)) + y = y.reshape((-1, 1)) + xy = np.hstack((x, y)) + + nn = NearestNeighbors(metric='chebyshev', n_neighbors=n_neighbors) + + nn.fit(xy) + radius = nn.kneighbors()[0] + radius = np.nextafter(radius[:, -1], 0) + + nn.set_params(algorithm='kd_tree') + + nn.fit(x) + ind = nn.radius_neighbors(radius=radius, return_distance=False) + nx = np.array([i.size for i in ind]) + + nn.fit(y) + ind = nn.radius_neighbors(radius=radius, return_distance=False) + ny = np.array([i.size for i in ind]) + + mi = (digamma(n_samples) + digamma(n_neighbors) - + np.mean(digamma(nx + 1)) - np.mean(digamma(ny + 1))) + + return max(0, mi) + + +def _compute_mi_cd(c, d, n_neighbors): + """Compute mutual information between continuous and discrete variables. + + Parameters + ---------- + c : ndarray + Samples from a continuous random variable. + d : ndarray + Samples from a discrete random variable. + n_neighbors : int + Number of nearest neighbors to search for each point, see [1]_. + + Returns + ------- + mi : float + Estimated mutual information. If it turned out to be negative it is + replace by 0. + + Notes + ----- + True mutual information can't be negative. If its estimate by a numerical + method is negative, it means (providing the method is adequate) that the + mutual information is close to 0 and replacing it by 0 is a reasonable + strategy. + + References + ---------- + .. [1] B. C. Ross "Mutual Information between Discrete and Continuous + Data Sets". PLoS ONE 9(2), 2014. + """ + c = c.reshape((-1, 1)) + n_samples = c.size + + nn = NearestNeighbors(n_neighbors=n_neighbors) + + radius = np.empty(n_samples) + label_counts = np.empty(n_samples) + for label in np.unique(d): + mask = d == label + count = np.sum(mask) + if count > 1: + nn.set_params(n_neighbors=min(n_neighbors, count - 1)) + nn.fit(c[mask]) + r = nn.kneighbors()[0] + radius[mask] = np.nextafter(r[:, -1], 0) + else: + radius[mask] = 0 + label_counts[mask] = count + + nn.set_params(algorithm='kd_tree') + nn.fit(c) + ind = nn.radius_neighbors(radius=radius, return_distance=False) + neighbor_counts = np.array([i.size for i in ind]) + + mi = (digamma(n_samples) + digamma(n_neighbors) - + np.mean(digamma(label_counts)) - + np.mean(digamma(neighbor_counts + 1))) + + return max(0, mi) + + +def _compute_mi(x, y, x_discrete, y_discrete, n_neighbors=3): + """Compute mutual information between two variables. + + This is a simple wrapper which selects a proper function to call based on + whether `x` and `y` are discrete or not. + """ + if x_discrete and y_discrete: + return mutual_info_score(x, y) + elif x_discrete and not y_discrete: + return _compute_mi_cd(y, x, n_neighbors) + elif not x_discrete and y_discrete: + return _compute_mi_cd(x, y, n_neighbors) + else: + return _compute_mi_cc(x, y, n_neighbors) + + +def _get_column(X, i): + """Get column of a matrix. + + Parameters + ---------- + X : ndarray or csc_matrix, shape (n_samples, n_features) + Matrix from which to get a column. + i : int + Column index. + + Returns + ------- + xi : ndarray, shape (n_samples,) + i-th column of `X` in dense format. 
+ """ + if issparse(X): + x = np.zeros(X.shape[0]) + start_ptr, end_ptr = X.indptr[i], X.indptr[i + 1] + x[X.indices[start_ptr:end_ptr]] = X.data[start_ptr:end_ptr] + else: + x = X[:, i] + + return x + + +def _iterate_columns(X, columns=None): + """Iterate over columns of a matrix. + + Parameters + ---------- + X : ndarray or csc_matrix, shape (n_samples, n_features) + Matrix over which to iterate. + columns : iterable or None, default None + Indices of columns to iterate over. If None, iterate over all columns. + + Yields + ------ + x : ndarray, shape (n_samples,) + Columns of `X` in dense format. + """ + if columns is None: + columns = range(X.shape[1]) + + for i in columns: + yield _get_column(X, i) + + +def mutual_info(X, y, discrete_features='auto', discrete_target=False, + n_neighbors=3, copy=True, random_state=None): + """Estimate mutual information (MI) between the features and the target. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + X : array_like or sparse matrix, shape (n_samples, n_features) + Feature matrix. + y : array_like, shape (n_samples,) + Target vector. + discrete_features : {'auto', bool, array_like}, default 'auto' + If bool, then determines whether to consider all features discrete + or continuous. If array, then it should be either a boolean mask + with shape (n_features,) or array with indices of discrete features. + If 'auto', assigned to False for dense `X` and to True for sparse `X`. + discrete_target : bool, default False + Whether to consider `y` as a discrete variable. + n_neighbors : int, default 3 + Number of neighbors to use for MI estimation for continuous variables, + see [2]_ and [3]_. Higher values reduce variance of the estimation, but + could increase a bias. + copy : bool, default True + Whether to make a copy of the given data. If set to False, the initial + data will be overwritten. + random_state : int seed, RandomState instance, or None (default=None) + The seed of the pseudo random number generator for adding small noise + to continuous variables in order to remove repeated values. + + Returns + ------- + mi : ndarray, shape (n_features,) + Estimated mutual information between each feature and the target. + + Notes + ----- + 1. Terms "discrete feature" and "discrete target" are used instead of + naming them as "categorical", because it describes the essence more + accurately. For example, pixel intensities of an image are discrete + features (but hardly categorical) and you will get better results if + mark them as such. Also note, that treating a continuous variable as + discrete and vice versa will usually give meaningless results, so be + attentive about that. + 2. True mutual information can't be negative. If its estimate turned out + to be negative, it will be replaced by zero. + + References + ---------- + .. [1] H. Peng, F. Long, and C. Ding, "Feature selection based on mutual + information: criteria of max-dependency, max-relevance, and + min-redundancy", IEEE Transactions on Pattern Analysis and Machine + Intelligence, Vol. 27, No. 8, pp. 1226-1238, 2005. + .. [2] A. Kraskov, H. Stogbauer and P. Grassberger, "Estimating mutual + information". Phys. Rev. E 69, 2004. + .. [3] B. C. Ross "Mutual Information between Discrete and Continuous + Data Sets". PLoS ONE 9(2), 2014. 
+ """ + X, y = check_X_y(X, y, accept_sparse='csc', y_numeric=not discrete_target) + n_samples, n_features = X.shape + + if discrete_features == 'auto': + discrete_features = issparse(X) + + if isinstance(discrete_features, bool): + discrete_mask = np.empty(n_features, dtype=bool) + discrete_mask.fill(discrete_features) + else: + discrete_features = np.asarray(discrete_features) + if discrete_features.dtype != 'bool': + discrete_mask = np.zeros(n_features, dtype=bool) + discrete_mask[discrete_features] = True + else: + discrete_mask = discrete_features + + continuous_mask = ~discrete_mask + if np.any(continuous_mask) and issparse(X): + raise ValueError("Sparse matrix `X` can't have continuous features.") + + if copy: + X = X.copy() + + if not discrete_target: + X[:, continuous_mask] = scale(X[:, continuous_mask], + with_mean=False, copy=False) + + # Add small noise to continuous features as advised in Kraskov et. al. + rng = check_random_state(random_state) + if np.any(continuous_mask): + X = X.astype(float) + means = np.maximum(1, np.mean(np.abs(X[:, continuous_mask]), axis=0)) + X[:, continuous_mask] += 1e-10 * means * rng.randn( + n_samples, np.sum(continuous_mask)) + + if not discrete_target: + y = scale(y, with_mean=False) + y += 1e-10 * np.maximum(1, np.mean(np.abs(y))) * rng.randn(n_samples) + + mi = [_compute_mi(x, y, discrete_feature, discrete_target) for + x, discrete_feature in zip(_iterate_columns(X), discrete_mask)] + + return np.array(mi) diff --git a/sklearn/feature_selection/tests/test_feature_select.py b/sklearn/feature_selection/tests/test_feature_select.py index 204d7c2e25dba..0cbd56c014532 100644 --- a/sklearn/feature_selection/tests/test_feature_select.py +++ b/sklearn/feature_selection/tests/test_feature_select.py @@ -24,10 +24,9 @@ from sklearn.datasets.samples_generator import (make_classification, make_regression) -from sklearn.feature_selection import (chi2, f_classif, f_oneway, f_regression, - SelectPercentile, SelectKBest, - SelectFpr, SelectFdr, SelectFwe, - GenericUnivariateSelect) +from sklearn.feature_selection import ( + chi2, mutual_info, f_classif, f_oneway, f_regression, SelectPercentile, + SelectKBest, SelectFpr, SelectFdr, SelectFwe, GenericUnivariateSelect) ############################################################################## @@ -556,3 +555,63 @@ def test_no_feature_selected(): X_selected = assert_warns_message( UserWarning, 'No features were selected', selector.transform, X) assert_equal(X_selected.shape, (40, 0)) + + +def test_mutual_info_classification(): + X, y = make_classification(n_samples=200, n_features=20, + n_informative=3, n_redundant=2, + n_repeated=0, n_classes=8, + n_clusters_per_class=1, flip_y=0.0, + class_sep=10, shuffle=False, random_state=0) + + score_func = lambda X, y: mutual_info(X, y, discrete_target=True) + + # Test in KBest mode. + univariate_filter = SelectKBest(score_func, k=5) + X_r = univariate_filter.fit(X, y).transform(X) + X_r2 = GenericUnivariateSelect( + score_func, mode='k_best', param=5).fit(X, y).transform(X) + assert_array_equal(X_r, X_r2) + support = univariate_filter.get_support() + gtruth = np.zeros(20) + gtruth[:5] = 1 + assert_array_equal(support, gtruth) + + # Test in Percentile mode. 
+ univariate_filter = SelectPercentile(score_func, percentile=25) + X_r = univariate_filter.fit(X, y).transform(X) + X_r2 = GenericUnivariateSelect( + score_func, mode='percentile', param=25).fit(X, y).transform(X) + assert_array_equal(X_r, X_r2) + support = univariate_filter.get_support() + gtruth = np.zeros(20) + gtruth[:5] = 1 + assert_array_equal(support, gtruth) + + +def test_mutual_info_regression(): + X, y = make_regression(n_samples=200, n_features=20, n_informative=5, + shuffle=False, random_state=0, noise=10) + + # Test in KBest mode. + univariate_filter = SelectKBest(mutual_info, k=5) + X_r = univariate_filter.fit(X, y).transform(X) + assert_best_scores_kept(univariate_filter) + X_r2 = GenericUnivariateSelect( + mutual_info, mode='k_best', param=5).fit(X, y).transform(X) + assert_array_equal(X_r, X_r2) + support = univariate_filter.get_support() + gtruth = np.zeros(20) + gtruth[:5] = 1 + assert_array_equal(support, gtruth) + + # Test in Percentile mode. + univariate_filter = SelectPercentile(mutual_info, percentile=25) + X_r = univariate_filter.fit(X, y).transform(X) + X_r2 = GenericUnivariateSelect( + mutual_info, mode='percentile', param=25).fit(X, y).transform(X) + assert_array_equal(X_r, X_r2) + support = univariate_filter.get_support() + gtruth = np.zeros(20) + gtruth[:5] = 1 + assert_array_equal(support, gtruth) diff --git a/sklearn/feature_selection/tests/test_mutual_info.py b/sklearn/feature_selection/tests/test_mutual_info.py new file mode 100644 index 0000000000000..a4fe0390be212 --- /dev/null +++ b/sklearn/feature_selection/tests/test_mutual_info.py @@ -0,0 +1,172 @@ +from __future__ import division + +import numpy as np +from numpy.testing import run_module_suite +from scipy.sparse import csr_matrix + +from sklearn.utils.testing import (assert_array_equal, assert_almost_equal, + assert_false, assert_true, assert_raises) +from sklearn.feature_selection.mutual_info import mutual_info, _compute_mi + + +class TestMIComputation(object): + def test_dd(self): + # In discrete case computations are straightforward and can be done + # by hand on given vectors. + x = np.array([0, 1, 1, 0, 0]) + y = np.array([1, 0, 0, 0, 1]) + + H_x = H_y = -(3/5) * np.log(3/5) - (2/5) * np.log(2/5) + H_xy = -1/5 * np.log(1/5) - 2/5 * np.log(2/5) - 2/5 * np.log(2/5) + I_xy = H_x + H_y - H_xy + + assert_almost_equal(_compute_mi(x, y, True, True), I_xy) + + def test_cc(self): + # For two continuous variables a good approach is to test on bivariate + # normal distribution, where mutual information is known. + + # Mean of the distribution, irrelevant for mutual information. + mean = np.zeros(2) + + # Setup covariance matrix with correlation coeff. equal 0.5. + sigma_1 = 1 + sigma_2 = 10 + corr = 0.5 + cov = np.array([ + [sigma_1**2, corr * sigma_1 * sigma_2], + [corr * sigma_1 * sigma_2, sigma_2**2] + ]) + + # True theoretical mutual information. + I_theory = (np.log(sigma_1) + np.log(sigma_2) - + 0.5 * np.log(np.linalg.det(cov))) + + np.random.seed(0) + Z = np.random.multivariate_normal(mean, cov, size=1000) + + x, y = Z[:, 0], Z[:, 1] + + # Theory and computed values won't be very close, assert that relative + # error is less than 10%. 
+ for n_neighbors in [3, 5, 7]: + I_computed = _compute_mi(x, y, False, False, n_neighbors) + assert_true(np.abs(I_computed - I_theory) < 0.1 * I_theory) + + def test_cd(self): + # To test define a joint distribution as follows: + # p(x, y) = p(x) p(y | x) + # X ~ Bernoulli(p) + # (Y | x = 0) ~ Uniform(-1, 1) + # (Y | x = 1) ~ Uniform(0, 2) + + # Use the following formula for mutual information: + # I(X; Y) = H(Y) - H(Y | X) + # Two entropies can be computed by hand: + # H(Y) = -(1-p)/2 * ln((1-p)/2) - p/2*log(p/2) - 1/2*log(1/2) + # H(Y | X) = ln(2) + + # Now we need to implement sampling from out distribution, which is + # done easily using conditional distribution logic. + + n_samples = 1000 + np.random.seed(0) + + for p in [0.3, 0.5, 0.7]: + x = np.random.uniform(size=n_samples) > p + + y = np.empty(n_samples) + mask = x == 0 + y[mask] = np.random.uniform(-1, 1, size=np.sum(mask)) + y[~mask] = np.random.uniform(0, 2, size=np.sum(~mask)) + + I_theory = -0.5 * ((1 - p) * np.log(0.5 * (1 - p)) + + p * np.log(0.5 * p) + np.log(0.5)) - np.log(2) + + # Again assert that relative error is less than 10%. + for n_neighbors in [3, 5, 7]: + I_computed = _compute_mi(x, y, True, False, n_neighbors) + assert_true(np.abs(I_computed - I_theory) < 0.1 * I_theory) + + +class TestMutualInfo(object): + # We are going test that feature ordering by MI matches our expectations. + def test_discrete(self): + X = np.array([[0, 0, 0], + [1, 1, 0], + [2, 0, 1], + [2, 0, 1], + [2, 0, 1]]) + y = np.array([0, 1, 2, 2, 1]) + + # Here X[:, 0] is the most informative feature, and X[:, 1] is weakly + # informative. + mi = mutual_info(X, y, discrete_features=True, discrete_target=True) + assert_array_equal(np.argsort(-mi), np.array([0, 2, 1])) + + def test_continuous(self): + # We generate sample from multivariate normal distribution, using + # transformation from initially uncorrelated variables. The first + # variables after transformation is selected as the target vector, + # it has the strongest correlation with the variable 2, and + # the weakest correlation with the variable 1. + T = np.array([ + [1, 0.5, 2, 1], + [0, 1, 0.1, 0.0], + [0, 0.1, 1, 0.1], + [0, 0.1, 0.1, 1] + ]) + cov = T.dot(T.T) + mean = np.zeros(4) + + np.random.seed(0) + Z = np.random.multivariate_normal(mean, cov, size=1000) + X = Z[:, 1:] + y = Z[:, 0] + + mi = mutual_info(X, y, random_state=0) + assert_array_equal(np.argsort(-mi), np.array([1, 2, 0])) + + def test_mixed(self): + # Here the target is discrete and there are two continuous and one + # discrete feature. The idea of this test is clear from the code. 
+ np.random.seed(0) + X = np.random.rand(1000, 3) + X[:, 1] += X[:, 0] + y = ((0.5 * X[:, 0] + X[:, 2]) > 0.5).astype(int) + X[:, 2] = X[:, 2] > 0.5 + + mi = mutual_info(X, y, discrete_features=[2], discrete_target=True, + random_state=0) + assert_array_equal(np.argsort(-mi), [2, 0, 1]) + + def test_discrete_features_option(self): + X = np.array([[0, 0, 0], + [1, 1, 0], + [2, 0, 1], + [2, 0, 1], + [2, 0, 1]]) + y = np.array([0, 1, 2, 2, 1]) + X_csr = csr_matrix(X) + + assert_raises(ValueError, mutual_info, X_csr, y, + discrete_features=False) + + mi_1 = mutual_info(X, y, discrete_features='auto', + discrete_target=True, random_state=0) + mi_2 = mutual_info(X, y, discrete_features=False, + discrete_target=True, random_state=0) + + mi_3 = mutual_info(X_csr, y, discrete_features='auto', + discrete_target=True) + mi_4 = mutual_info(X_csr, y, discrete_features=True, + discrete_target=True) + + assert_array_equal(mi_1, mi_2) + assert_array_equal(mi_3, mi_4) + + assert_false(np.allclose(mi_1, mi_3)) + + +if __name__ == '__main__': + run_module_suite() diff --git a/sklearn/feature_selection/univariate_selection.py b/sklearn/feature_selection/univariate_selection.py index 9bd8ca273a8dc..b74c5cd5928e3 100644 --- a/sklearn/feature_selection/univariate_selection.py +++ b/sklearn/feature_selection/univariate_selection.py @@ -295,7 +295,7 @@ class _BaseFilter(BaseEstimator, SelectorMixin): ---------- score_func : callable Function taking two arrays X and y, and returning a pair of arrays - (scores, pvalues). + (scores, pvalues) or a single array scores. """ def __init__(self, score_func): @@ -326,10 +326,16 @@ def fit(self, X, y): % (self.score_func, type(self.score_func))) self._check_params(X, y) + score_func_ret = self.score_func(X, y) + if isinstance(score_func_ret, (list, tuple)): + self.scores_, self.pvalues_ = score_func_ret + self.pvalues_ = np.asarray(self.pvalues_) + else: + self.scores_ = score_func_ret + self.pvalues_ = None - self.scores_, self.pvalues_ = self.score_func(X, y) self.scores_ = np.asarray(self.scores_) - self.pvalues_ = np.asarray(self.pvalues_) + return self def _check_params(self, X, y): @@ -348,7 +354,7 @@ class SelectPercentile(_BaseFilter): ---------- score_func : callable Function taking two arrays X and y, and returning a pair of arrays - (scores, pvalues). + (scores, pvalues) or a single array scores. percentile : int, optional, default=10 Percent of features to keep. @@ -359,7 +365,7 @@ class SelectPercentile(_BaseFilter): Scores of features. pvalues_ : array-like, shape=(n_features,) - p-values of feature scores. + p-values of feature scores, None if `score_func` returned scores only. Notes ----- @@ -371,6 +377,7 @@ class SelectPercentile(_BaseFilter): f_classif: ANOVA F-value between label/feature for classification tasks. chi2: Chi-squared stats of non-negative features for classification tasks. f_regression: F-value between label/feature for regression tasks. + mutual_info: Mutual information between features and the target. SelectKBest: Select features based on the k highest scores. SelectFpr: Select features based on a false positive rate test. SelectFdr: Select features based on an estimated false discovery rate. @@ -429,7 +436,7 @@ class SelectKBest(_BaseFilter): Scores of features. pvalues_ : array-like, shape=(n_features,) - p-values of feature scores. + p-values of feature scores, None if `score_func` returned scores only. Notes ----- @@ -441,6 +448,7 @@ class SelectKBest(_BaseFilter): f_classif: ANOVA F-value between label/feature for classification tasks. 
chi2: Chi-squared stats of non-negative features for classification tasks. f_regression: F-value between label/feature for regression tasks. + mutual_info: Mutual information between features and the target. SelectPercentile: Select features based on percentile of the highest scores. SelectFpr: Select features based on a false positive rate test. SelectFdr: Select features based on an estimated false discovery rate. @@ -505,6 +513,7 @@ class SelectFpr(_BaseFilter): f_classif: ANOVA F-value between label/feature for classification tasks. chi2: Chi-squared stats of non-negative features for classification tasks. f_regression: F-value between label/feature for regression tasks. + mutual_info: Mutual information between features and the target. SelectPercentile: Select features based on percentile of the highest scores. SelectKBest: Select features based on the k highest scores. SelectFdr: Select features based on an estimated false discovery rate. @@ -557,6 +566,7 @@ class SelectFdr(_BaseFilter): f_classif: ANOVA F-value between label/feature for classification tasks. chi2: Chi-squared stats of non-negative features for classification tasks. f_regression: F-value between label/feature for regression tasks. + mutual_info: Mutual information between features and the target. SelectPercentile: Select features based on percentile of the highest scores. SelectKBest: Select features based on the k highest scores. SelectFpr: Select features based on a false positive rate test. @@ -607,6 +617,7 @@ class SelectFwe(_BaseFilter): f_classif: ANOVA F-value between label/feature for classification tasks. chi2: Chi-squared stats of non-negative features for classification tasks. f_regression: F-value between label/feature for regression tasks. + mutual_info: Mutual information between features and the target. SelectPercentile: Select features based on percentile of the highest scores. SelectKBest: Select features based on the k highest scores. SelectFpr: Select features based on a false positive rate test. @@ -639,7 +650,8 @@ class GenericUnivariateSelect(_BaseFilter): ---------- score_func : callable Function taking two arrays X and y, and returning a pair of arrays - (scores, pvalues). + (scores, pvalues). For modes 'percentile' or 'kbest' it can return + a single array scores. mode : {'percentile', 'k_best', 'fpr', 'fdr', 'fwe'} Feature selection mode. @@ -653,13 +665,14 @@ class GenericUnivariateSelect(_BaseFilter): Scores of features. pvalues_ : array-like, shape=(n_features,) - p-values of feature scores. + p-values of feature scores, None if `score_func` returned scores only. See also -------- f_classif: ANOVA F-value between label/feature for classification tasks. chi2: Chi-squared stats of non-negative features for classification tasks. f_regression: F-value between label/feature for regression tasks. + mutual_info: Mutual information between features and the target. SelectPercentile: Select features based on percentile of the highest scores. SelectKBest: Select features based on the k highest scores. SelectFpr: Select features based on a false positive rate test. 
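Editor's note: a usage sketch, not part of the patch series. The scorer added in PATCH 01 returns a single array of scores rather than the usual (scores, pvalues) pair, and the `_BaseFilter.fit` change above makes the univariate selectors accept that. A minimal doctest-style sketch, assuming the PATCH 01 API (`mutual_info(X, y, discrete_target=...)`); PATCH 06 further below later splits this function into `mutual_info_classif` and `mutual_info_regression`:

    >>> from sklearn.datasets import make_classification
    >>> from sklearn.feature_selection import SelectKBest, mutual_info
    >>> X, y = make_classification(n_samples=200, n_features=20,
    ...                            n_informative=5, shuffle=False,
    ...                            random_state=0)
    >>> # Wrap mutual_info so the classification target is treated as
    >>> # discrete, mirroring the pattern used in test_feature_select.py.
    >>> score_func = lambda X, y: mutual_info(X, y, discrete_target=True)
    >>> selector = SelectKBest(score_func, k=5).fit(X, y)
    >>> selector.transform(X).shape
    (200, 5)
    >>> selector.pvalues_ is None  # scores-only score functions set no p-values
    True

Because no p-values are produced, only the 'percentile' and 'k_best' modes of GenericUnivariateSelect (and SelectPercentile/SelectKBest) can use such a scorer per this patch; SelectFpr, SelectFdr and SelectFwe still require a (scores, pvalues) pair.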
From 54c078311ed056d124e49e9bef3c7df868ba0403 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Sun, 13 Dec 2015 03:47:32 +0500 Subject: [PATCH 02/25] DOC: Documentation update related to mutual_info --- doc/modules/classes.rst | 1 + doc/modules/feature_selection.rst | 16 +++++++++------- sklearn/feature_selection/mutual_info.py | 24 +++++++++++++++--------- 3 files changed, 25 insertions(+), 16 deletions(-) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 14969f2969713..9eca6e287123e 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -534,6 +534,7 @@ From text feature_selection.chi2 feature_selection.f_classif feature_selection.f_regression + feature_selection.mutual_info .. _gaussian_process_ref: diff --git a/doc/modules/feature_selection.rst b/doc/modules/feature_selection.rst index 679db067ebf8a..87444054741b2 100644 --- a/doc/modules/feature_selection.rst +++ b/doc/modules/feature_selection.rst @@ -67,8 +67,8 @@ as objects that implement the ``transform`` method: :class:`SelectFdr`, or family wise error :class:`SelectFwe`. * :class:`GenericUnivariateSelect` allows to perform univariate feature - selection with a configurable strategy. This allows to select the best - univariate selection strategy with hyper-parameter search estimator. + selection with a configurable strategy. This allows to select the best + univariate selection strategy with hyper-parameter search estimator. For instance, we can perform a :math:`\chi^2` test to the samples to retrieve only the two best features as follows: @@ -84,17 +84,19 @@ to retrieve only the two best features as follows: >>> X_new.shape (150, 2) -These objects take as input a scoring function that returns -univariate p-values: +These objects take as input a scoring function that returns univariate scores +and p-values (or only scores for :class:`SelectKBest` and +:class:`SelectPercentile`): - * For regression: :func:`f_regression` + * For regression: :func:`f_regression`, :func:`mutual_info` - * For classification: :func:`chi2` or :func:`f_classif` + * For classification: :func:`chi2`, :func:`f_classif`, :func:`mutual_info` .. topic:: Feature selection with sparse data If you use sparse data (i.e. data represented as sparse matrices), - only :func:`chi2` will deal with the data without making it dense. + :func:`chi2` and :func:`mutual_info` will deal with the data without making + it dense. .. warning:: diff --git a/sklearn/feature_selection/mutual_info.py b/sklearn/feature_selection/mutual_info.py index 3638712ef7e74..145d5d067040e 100644 --- a/sklearn/feature_selection/mutual_info.py +++ b/sklearn/feature_selection/mutual_info.py @@ -195,9 +195,17 @@ def _iterate_columns(X, columns=None): def mutual_info(X, y, discrete_features='auto', discrete_target=False, n_neighbors=3, copy=True, random_state=None): - """Estimate mutual information (MI) between the features and the target. + """Estimate mutual information between the features and the target. - Read more in the :ref:`User Guide `. + Mutual information (MI) [1]_ between two random variables is a non-negative + value, which measures the dependency between the variables. It is equal + to zero if and only if two random variables are independent, higher values + mean higher dependency. + + The function is capable of estimating MI between continuous and discrete + variables as described in [1]_ and [2]_. + It can be used for univariate features selection, read more in the + :ref:`User Guide `. 
Parameters ---------- @@ -219,7 +227,7 @@ def mutual_info(X, y, discrete_features='auto', discrete_target=False, copy : bool, default True Whether to make a copy of the given data. If set to False, the initial data will be overwritten. - random_state : int seed, RandomState instance, or None (default=None) + random_state : int seed, RandomState instance or None, default None The seed of the pseudo random number generator for adding small noise to continuous variables in order to remove repeated values. @@ -237,15 +245,13 @@ def mutual_info(X, y, discrete_features='auto', discrete_target=False, mark them as such. Also note, that treating a continuous variable as discrete and vice versa will usually give meaningless results, so be attentive about that. - 2. True mutual information can't be negative. If its estimate turned out - to be negative, it will be replaced by zero. + 2. True mutual information can't be negative. If its estimate turns out + to be negative, it is replaced by zero. References ---------- - .. [1] H. Peng, F. Long, and C. Ding, "Feature selection based on mutual - information: criteria of max-dependency, max-relevance, and - min-redundancy", IEEE Transactions on Pattern Analysis and Machine - Intelligence, Vol. 27, No. 8, pp. 1226-1238, 2005. + .. [1] `Mutual Information `_ + on Wikipedia. .. [2] A. Kraskov, H. Stogbauer and P. Grassberger, "Estimating mutual information". Phys. Rev. E 69, 2004. .. [3] B. C. Ross "Mutual Information between Discrete and Continuous From 0245fc8aec4d78087de67b723ce81d8551159a03 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Sun, 13 Dec 2015 03:56:18 +0500 Subject: [PATCH 03/25] MAINT: Use six.moves.zip in mutual_info --- sklearn/feature_selection/mutual_info.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/feature_selection/mutual_info.py b/sklearn/feature_selection/mutual_info.py index 145d5d067040e..73f113c5eff09 100644 --- a/sklearn/feature_selection/mutual_info.py +++ b/sklearn/feature_selection/mutual_info.py @@ -6,6 +6,7 @@ from scipy.sparse import issparse from scipy.special import digamma +from ..externals.six import moves from ..metrics.cluster.supervised import mutual_info_score from ..neighbors import NearestNeighbors from ..preprocessing import scale @@ -298,6 +299,6 @@ def mutual_info(X, y, discrete_features='auto', discrete_target=False, y += 1e-10 * np.maximum(1, np.mean(np.abs(y))) * rng.randn(n_samples) mi = [_compute_mi(x, y, discrete_feature, discrete_target) for - x, discrete_feature in zip(_iterate_columns(X), discrete_mask)] + x, discrete_feature in moves.zip(_iterate_columns(X), discrete_mask)] return np.array(mi) From c1aea3f868652e23eb01a1ebaf22ab0e77c4e9ee Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Tue, 15 Dec 2015 02:15:01 +0500 Subject: [PATCH 04/25] MAINT: Renamed module mutual_info to mutual_info_ --- sklearn/feature_selection/__init__.py | 4 ++-- .../{mutual_info.py => mutual_info_.py} | 13 +++++++------ sklearn/feature_selection/tests/test_mutual_info.py | 4 ++-- 3 files changed, 11 insertions(+), 10 deletions(-) rename sklearn/feature_selection/{mutual_info.py => mutual_info_.py} (96%) diff --git a/sklearn/feature_selection/__init__.py b/sklearn/feature_selection/__init__.py index e41fa12cd94e8..204b4604c8ebb 100644 --- a/sklearn/feature_selection/__init__.py +++ b/sklearn/feature_selection/__init__.py @@ -22,7 +22,7 @@ from .from_model import SelectFromModel -from .mutual_info import mutual_info +from .mutual_info_ import mutual_info __all__ = 
['GenericUnivariateSelect', @@ -32,7 +32,7 @@ 'SelectFpr', 'SelectFwe', 'SelectKBest', - 'SelectFromModel' + 'SelectFromModel', 'SelectPercentile', 'VarianceThreshold', 'chi2', diff --git a/sklearn/feature_selection/mutual_info.py b/sklearn/feature_selection/mutual_info_.py similarity index 96% rename from sklearn/feature_selection/mutual_info.py rename to sklearn/feature_selection/mutual_info_.py index 73f113c5eff09..421477081e91a 100644 --- a/sklearn/feature_selection/mutual_info.py +++ b/sklearn/feature_selection/mutual_info_.py @@ -200,10 +200,10 @@ def mutual_info(X, y, discrete_features='auto', discrete_target=False, Mutual information (MI) [1]_ between two random variables is a non-negative value, which measures the dependency between the variables. It is equal - to zero if and only if two random variables are independent, higher values - mean higher dependency. + to zero if and only if two random variables are independent, and higher + values mean higher dependency. - The function is capable of estimating MI between continuous and discrete + This function is capable of estimating MI between continuous and discrete variables as described in [1]_ and [2]_. It can be used for univariate features selection, read more in the :ref:`User Guide `. @@ -218,13 +218,14 @@ def mutual_info(X, y, discrete_features='auto', discrete_target=False, If bool, then determines whether to consider all features discrete or continuous. If array, then it should be either a boolean mask with shape (n_features,) or array with indices of discrete features. - If 'auto', assigned to False for dense `X` and to True for sparse `X`. + If 'auto', it is assigned to False for dense `X` and to True for + sparse `X`. discrete_target : bool, default False Whether to consider `y` as a discrete variable. n_neighbors : int, default 3 Number of neighbors to use for MI estimation for continuous variables, see [2]_ and [3]_. Higher values reduce variance of the estimation, but - could increase a bias. + could introduce a bias. copy : bool, default True Whether to make a copy of the given data. If set to False, the initial data will be overwritten. @@ -244,7 +245,7 @@ def mutual_info(X, y, discrete_features='auto', discrete_target=False, accurately. For example, pixel intensities of an image are discrete features (but hardly categorical) and you will get better results if mark them as such. Also note, that treating a continuous variable as - discrete and vice versa will usually give meaningless results, so be + discrete and vice versa will usually give incorrect results, so be attentive about that. 2. True mutual information can't be negative. If its estimate turns out to be negative, it is replaced by zero. diff --git a/sklearn/feature_selection/tests/test_mutual_info.py b/sklearn/feature_selection/tests/test_mutual_info.py index a4fe0390be212..13e1a7ba49b1b 100644 --- a/sklearn/feature_selection/tests/test_mutual_info.py +++ b/sklearn/feature_selection/tests/test_mutual_info.py @@ -6,7 +6,7 @@ from sklearn.utils.testing import (assert_array_equal, assert_almost_equal, assert_false, assert_true, assert_raises) -from sklearn.feature_selection.mutual_info import mutual_info, _compute_mi +from sklearn.feature_selection.mutual_info_ import mutual_info, _compute_mi class TestMIComputation(object): @@ -106,7 +106,7 @@ def test_discrete(self): def test_continuous(self): # We generate sample from multivariate normal distribution, using - # transformation from initially uncorrelated variables. 
The first + # transformation from initially uncorrelated variables. The zero # variables after transformation is selected as the target vector, # it has the strongest correlation with the variable 2, and # the weakest correlation with the variable 1. From 689ed0dbba7f2abc6c076b235958f09f7b4953e5 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Sun, 10 Jan 2016 22:01:36 +0500 Subject: [PATCH 05/25] DOC: Example for mutual_information --- doc/modules/feature_selection.rst | 4 +- .../feature_selection/plot_f_test_vs_mi.py | 47 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 examples/feature_selection/plot_f_test_vs_mi.py diff --git a/doc/modules/feature_selection.rst b/doc/modules/feature_selection.rst index 87444054741b2..fcff03c2e0631 100644 --- a/doc/modules/feature_selection.rst +++ b/doc/modules/feature_selection.rst @@ -105,7 +105,9 @@ and p-values (or only scores for :class:`SelectKBest` and .. topic:: Examples: - :ref:`example_feature_selection_plot_feature_selection.py` + * :ref:`example_feature_selection_plot_feature_selection.py` + + * :ref:`example_feature_selection_plot_f_test_vs_mi.py` .. _rfe: diff --git a/examples/feature_selection/plot_f_test_vs_mi.py b/examples/feature_selection/plot_f_test_vs_mi.py new file mode 100644 index 0000000000000..d5a07d1439507 --- /dev/null +++ b/examples/feature_selection/plot_f_test_vs_mi.py @@ -0,0 +1,47 @@ +""" +=========================================== +Comparison of F-test and mutual information +=========================================== + +This example illustrates the differences between univariate F-test statistics +and mutual information. + +We consider 3 features x_1, x_2, x_3 distributed uniformly over [0, 1], the +target depends on them as follows: + +y = x_1 + sin(6 * pi * x_2) + 0.1 * N(0, 1), that is the third features is completely irrelevant. + +The code below plots the dependency of y against individual x_i and normalized +values of univariate F-tests statistics and mutual information. + +As F-test captures only linear dependency, it rates x_1 as the most +discriminative feature. On the other hand, mutual information can capture any +kind of dependency between variables and it rates x_2 as the most +discriminative feature, which probably agrees better with our intuitive +perception for this example. Both methods correctly marks x_3 as irrelevant. 
+""" +print(__doc__) + +import numpy as np +import matplotlib.pyplot as plt +from sklearn.feature_selection import mutual_info, f_regression + +X = np.random.rand(1000, 3) +y = X[:, 0] + np.sin(6 * np.pi * X[:, 1]) + 0.1 * np.random.randn(1000) + +f_test, _ = f_regression(X, y) +f_test /= np.max(f_test) + +mi = mutual_info(X, y) +mi /= np.max(mi) + +plt.figure(figsize=(15, 5)) +for i in range(3): + plt.subplot(1, 3, i + 1) + plt.scatter(X[:, i], y) + plt.xlabel("$x_{}$".format(i + 1), fontsize=14) + if i == 0: + plt.ylabel("$y$", fontsize=14) + plt.title("F-test={:.2f}, MI={:.2f}".format(f_test[i], mi[i]), + fontsize=16) +plt.show() From 835102ab6a4bae273e0641d3634e4849aa36a256 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Tue, 12 Jan 2016 22:43:23 +0500 Subject: [PATCH 06/25] API: Split mutual_info into _regression and _classif --- doc/modules/classes.rst | 3 +- doc/modules/feature_selection.rst | 8 +- .../feature_selection/plot_f_test_vs_mi.py | 4 +- sklearn/feature_selection/__init__.py | 5 +- sklearn/feature_selection/mutual_info_.py | 171 +++++++++++++++--- .../tests/test_feature_select.py | 25 ++- .../tests/test_mutual_info.py | 37 ++-- .../feature_selection/univariate_selection.py | 22 ++- 8 files changed, 195 insertions(+), 80 deletions(-) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 9eca6e287123e..5023564df7c55 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -534,7 +534,8 @@ From text feature_selection.chi2 feature_selection.f_classif feature_selection.f_regression - feature_selection.mutual_info + feature_selection.mutual_info_classif + feature_selection.mutual_info_regression .. _gaussian_process_ref: diff --git a/doc/modules/feature_selection.rst b/doc/modules/feature_selection.rst index fcff03c2e0631..2b70c25af82a9 100644 --- a/doc/modules/feature_selection.rst +++ b/doc/modules/feature_selection.rst @@ -88,15 +88,15 @@ These objects take as input a scoring function that returns univariate scores and p-values (or only scores for :class:`SelectKBest` and :class:`SelectPercentile`): - * For regression: :func:`f_regression`, :func:`mutual_info` + * For regression: :func:`f_regression`, :func:`mutual_info_regression` - * For classification: :func:`chi2`, :func:`f_classif`, :func:`mutual_info` + * For classification: :func:`chi2`, :func:`f_classif`, :func:`mutual_info_classif` .. topic:: Feature selection with sparse data If you use sparse data (i.e. data represented as sparse matrices), - :func:`chi2` and :func:`mutual_info` will deal with the data without making - it dense. + :func:`chi2`, :func:`mutual_info_regression`, :func:`mutual_info_classif` + will deal with the data without making it dense. .. 
warning:: diff --git a/examples/feature_selection/plot_f_test_vs_mi.py b/examples/feature_selection/plot_f_test_vs_mi.py index d5a07d1439507..ea178be9e99c9 100644 --- a/examples/feature_selection/plot_f_test_vs_mi.py +++ b/examples/feature_selection/plot_f_test_vs_mi.py @@ -24,7 +24,7 @@ import numpy as np import matplotlib.pyplot as plt -from sklearn.feature_selection import mutual_info, f_regression +from sklearn.feature_selection import f_regression, mutual_info_regression X = np.random.rand(1000, 3) y = X[:, 0] + np.sin(6 * np.pi * X[:, 1]) + 0.1 * np.random.randn(1000) @@ -32,7 +32,7 @@ f_test, _ = f_regression(X, y) f_test /= np.max(f_test) -mi = mutual_info(X, y) +mi = mutual_info_regression(X, y) mi /= np.max(mi) plt.figure(figsize=(15, 5)) diff --git a/sklearn/feature_selection/__init__.py b/sklearn/feature_selection/__init__.py index 204b4604c8ebb..ffa392b5b26db 100644 --- a/sklearn/feature_selection/__init__.py +++ b/sklearn/feature_selection/__init__.py @@ -22,7 +22,7 @@ from .from_model import SelectFromModel -from .mutual_info_ import mutual_info +from .mutual_info_ import mutual_info_regression, mutual_info_classif __all__ = ['GenericUnivariateSelect', @@ -39,4 +39,5 @@ 'f_classif', 'f_oneway', 'f_regression', - 'mutual_info'] + 'mutual_info_classif', + 'mutual_info_regression'] diff --git a/sklearn/feature_selection/mutual_info_.py b/sklearn/feature_selection/mutual_info_.py index 421477081e91a..d0903dc5e5120 100644 --- a/sklearn/feature_selection/mutual_info_.py +++ b/sklearn/feature_selection/mutual_info_.py @@ -194,20 +194,10 @@ def _iterate_columns(X, columns=None): yield _get_column(X, i) -def mutual_info(X, y, discrete_features='auto', discrete_target=False, - n_neighbors=3, copy=True, random_state=None): +def _estimate_mi(X, y, discrete_features='auto', discrete_target=False, + n_neighbors=3, copy=True, random_state=None): """Estimate mutual information between the features and the target. - Mutual information (MI) [1]_ between two random variables is a non-negative - value, which measures the dependency between the variables. It is equal - to zero if and only if two random variables are independent, and higher - values mean higher dependency. - - This function is capable of estimating MI between continuous and discrete - variables as described in [1]_ and [2]_. - It can be used for univariate features selection, read more in the - :ref:`User Guide `. - Parameters ---------- X : array_like or sparse matrix, shape (n_samples, n_features) @@ -224,7 +214,7 @@ def mutual_info(X, y, discrete_features='auto', discrete_target=False, Whether to consider `y` as a discrete variable. n_neighbors : int, default 3 Number of neighbors to use for MI estimation for continuous variables, - see [2]_ and [3]_. Higher values reduce variance of the estimation, but + see [1]_ and [2]_. Higher values reduce variance of the estimation, but could introduce a bias. copy : bool, default True Whether to make a copy of the given data. If set to False, the initial @@ -237,26 +227,13 @@ def mutual_info(X, y, discrete_features='auto', discrete_target=False, ------- mi : ndarray, shape (n_features,) Estimated mutual information between each feature and the target. - - Notes - ----- - 1. Terms "discrete feature" and "discrete target" are used instead of - naming them as "categorical", because it describes the essence more - accurately. For example, pixel intensities of an image are discrete - features (but hardly categorical) and you will get better results if - mark them as such. 
Also note, that treating a continuous variable as - discrete and vice versa will usually give incorrect results, so be - attentive about that. - 2. True mutual information can't be negative. If its estimate turns out - to be negative, it is replaced by zero. + A negative value will be replaced by 0. References ---------- - .. [1] `Mutual Information `_ - on Wikipedia. - .. [2] A. Kraskov, H. Stogbauer and P. Grassberger, "Estimating mutual + .. [1] A. Kraskov, H. Stogbauer and P. Grassberger, "Estimating mutual information". Phys. Rev. E 69, 2004. - .. [3] B. C. Ross "Mutual Information between Discrete and Continuous + .. [2] B. C. Ross "Mutual Information between Discrete and Continuous Data Sets". PLoS ONE 9(2), 2014. """ X, y = check_X_y(X, y, accept_sparse='csc', y_numeric=not discrete_target) @@ -283,7 +260,7 @@ def mutual_info(X, y, discrete_features='auto', discrete_target=False, if copy: X = X.copy() - if not discrete_target: + if not discrete_target and np.any(continuous_mask): X[:, continuous_mask] = scale(X[:, continuous_mask], with_mean=False, copy=False) @@ -303,3 +280,137 @@ def mutual_info(X, y, discrete_features='auto', discrete_target=False, x, discrete_feature in moves.zip(_iterate_columns(X), discrete_mask)] return np.array(mi) + + +def mutual_info_regression(X, y, discrete_features='auto', n_neighbors=3, + copy=True, random_state=None): + """Estimate mutual information for a continuous target variable. + + Mutual information (MI) [1]_ between two random variables is a non-negative + value, which measures the dependency between the variables. It is equal + to zero if and only if two random variables are independent, and higher + values mean higher dependency. + + This function relies on the algorithms of MI estimation described in [2]_ + and [3]_. + + It can be used for univariate features selection, read more in the + :ref:`User Guide `. + + Parameters + ---------- + X : array_like or sparse matrix, shape (n_samples, n_features) + Feature matrix. + y : array_like, shape (n_samples,) + Target vector. + discrete_features : {'auto', bool, array_like}, default 'auto' + If bool, then determines whether to consider all features discrete + or continuous. If array, then it should be either a boolean mask + with shape (n_features,) or array with indices of discrete features. + If 'auto', it is assigned to False for dense `X` and to True for + sparse `X`. + n_neighbors : int, default 3 + Number of neighbors to use for MI estimation for continuous variables, + see [2]_ and [3]_. Higher values reduce variance of the estimation, but + could introduce a bias. + copy : bool, default True + Whether to make a copy of the given data. If set to False, the initial + data will be overwritten. + random_state : int seed, RandomState instance or None, default None + The seed of the pseudo random number generator for adding small noise + to continuous variables in order to remove repeated values. + + Returns + ------- + mi : ndarray, shape (n_features,) + Estimated mutual information between each feature and the target. + + Notes + ----- + 1. The term "discrete features" is used instead of naming them + "categorical", because it describes the essence more accurately. + For example, pixel intensities of an image are discrete features + (but hardly categorical) and you will get better results if mark them + as such. Also note, that treating a continuous variable as discrete and + vice versa will usually give incorrect results, so be attentive about that. + 2. 
True mutual information can't be negative. If its estimate turns out + to be negative, it is replaced by zero. + + References + ---------- + .. [1] `Mutual Information `_ + on Wikipedia. + .. [2] A. Kraskov, H. Stogbauer and P. Grassberger, "Estimating mutual + information". Phys. Rev. E 69, 2004. + .. [3] B. C. Ross "Mutual Information between Discrete and Continuous + Data Sets". PLoS ONE 9(2), 2014. + """ + return _estimate_mi(X, y, discrete_features, False, n_neighbors, + copy, random_state) + + +def mutual_info_classif(X, y, discrete_features='auto', n_neighbors=3, + copy=True, random_state=None): + """Estimate mutual information for a discrete target variable. + + Mutual information (MI) [1]_ between two random variables is a non-negative + value, which measures the dependency between the variables. It is equal + to zero if and only if two random variables are independent, and higher + values mean higher dependency. + + This function relies on the algorithms of MI estimation described in [2]_ + and [3]_. + + It can be used for univariate features selection, read more in the + :ref:`User Guide `. + + Parameters + ---------- + X : array_like or sparse matrix, shape (n_samples, n_features) + Feature matrix. + y : array_like, shape (n_samples,) + Target vector. + discrete_features : {'auto', bool, array_like}, default 'auto' + If bool, then determines whether to consider all features discrete + or continuous. If array, then it should be either a boolean mask + with shape (n_features,) or array with indices of discrete features. + If 'auto', it is assigned to False for dense `X` and to True for + sparse `X`. + n_neighbors : int, default 3 + Number of neighbors to use for MI estimation for continuous variables, + see [2]_ and [3]_. Higher values reduce variance of the estimation, but + could introduce a bias. + copy : bool, default True + Whether to make a copy of the given data. If set to False, the initial + data will be overwritten. + random_state : int seed, RandomState instance or None, default None + The seed of the pseudo random number generator for adding small noise + to continuous variables in order to remove repeated values. + + Returns + ------- + mi : ndarray, shape (n_features,) + Estimated mutual information between each feature and the target. + + Notes + ----- + 1. The term "discrete features" is used instead of naming them + "categorical", because it describes the essence more accurately. + For example, pixel intensities of an image are discrete features + (but hardly categorical) and you will get better results if mark them + as such. Also note, that treating a continuous variable as discrete and + vice versa will usually give incorrect results, so be attentive about that. + 2. True mutual information can't be negative. If its estimate turns out + to be negative, it is replaced by zero. + + References + ---------- + .. [1] `Mutual Information `_ + on Wikipedia. + .. [2] A. Kraskov, H. Stogbauer and P. Grassberger, "Estimating mutual + information". Phys. Rev. E 69, 2004. + .. [3] B. C. Ross "Mutual Information between Discrete and Continuous + Data Sets". PLoS ONE 9(2), 2014. 
+ """ + return _estimate_mi(X, y, discrete_features, True, n_neighbors, + copy, random_state) diff --git a/sklearn/feature_selection/tests/test_feature_select.py b/sklearn/feature_selection/tests/test_feature_select.py index 0cbd56c014532..e23b9a3c20843 100644 --- a/sklearn/feature_selection/tests/test_feature_select.py +++ b/sklearn/feature_selection/tests/test_feature_select.py @@ -25,8 +25,9 @@ from sklearn.datasets.samples_generator import (make_classification, make_regression) from sklearn.feature_selection import ( - chi2, mutual_info, f_classif, f_oneway, f_regression, SelectPercentile, - SelectKBest, SelectFpr, SelectFdr, SelectFwe, GenericUnivariateSelect) + chi2, f_classif, f_oneway, f_regression, mutual_info_classif, + mutual_info_regression, SelectPercentile, SelectKBest, SelectFpr, + SelectFdr, SelectFwe, GenericUnivariateSelect) ############################################################################## @@ -564,13 +565,11 @@ def test_mutual_info_classification(): n_clusters_per_class=1, flip_y=0.0, class_sep=10, shuffle=False, random_state=0) - score_func = lambda X, y: mutual_info(X, y, discrete_target=True) - # Test in KBest mode. - univariate_filter = SelectKBest(score_func, k=5) + univariate_filter = SelectKBest(mutual_info_classif, k=5) X_r = univariate_filter.fit(X, y).transform(X) X_r2 = GenericUnivariateSelect( - score_func, mode='k_best', param=5).fit(X, y).transform(X) + mutual_info_classif, mode='k_best', param=5).fit(X, y).transform(X) assert_array_equal(X_r, X_r2) support = univariate_filter.get_support() gtruth = np.zeros(20) @@ -578,10 +577,10 @@ def test_mutual_info_classification(): assert_array_equal(support, gtruth) # Test in Percentile mode. - univariate_filter = SelectPercentile(score_func, percentile=25) + univariate_filter = SelectPercentile(mutual_info_classif, percentile=25) X_r = univariate_filter.fit(X, y).transform(X) X_r2 = GenericUnivariateSelect( - score_func, mode='percentile', param=25).fit(X, y).transform(X) + mutual_info_classif, mode='percentile', param=25).fit(X, y).transform(X) assert_array_equal(X_r, X_r2) support = univariate_filter.get_support() gtruth = np.zeros(20) @@ -594,11 +593,11 @@ def test_mutual_info_regression(): shuffle=False, random_state=0, noise=10) # Test in KBest mode. - univariate_filter = SelectKBest(mutual_info, k=5) + univariate_filter = SelectKBest(mutual_info_regression, k=5) X_r = univariate_filter.fit(X, y).transform(X) assert_best_scores_kept(univariate_filter) X_r2 = GenericUnivariateSelect( - mutual_info, mode='k_best', param=5).fit(X, y).transform(X) + mutual_info_regression, mode='k_best', param=5).fit(X, y).transform(X) assert_array_equal(X_r, X_r2) support = univariate_filter.get_support() gtruth = np.zeros(20) @@ -606,10 +605,10 @@ def test_mutual_info_regression(): assert_array_equal(support, gtruth) # Test in Percentile mode. 
- univariate_filter = SelectPercentile(mutual_info, percentile=25) + univariate_filter = SelectPercentile(mutual_info_regression, percentile=25) X_r = univariate_filter.fit(X, y).transform(X) - X_r2 = GenericUnivariateSelect( - mutual_info, mode='percentile', param=25).fit(X, y).transform(X) + X_r2 = GenericUnivariateSelect(mutual_info_regression, mode='percentile', + param=25).fit(X, y).transform(X) assert_array_equal(X_r, X_r2) support = univariate_filter.get_support() gtruth = np.zeros(20) diff --git a/sklearn/feature_selection/tests/test_mutual_info.py b/sklearn/feature_selection/tests/test_mutual_info.py index 13e1a7ba49b1b..6c7bd7cbde5d5 100644 --- a/sklearn/feature_selection/tests/test_mutual_info.py +++ b/sklearn/feature_selection/tests/test_mutual_info.py @@ -6,7 +6,8 @@ from sklearn.utils.testing import (assert_array_equal, assert_almost_equal, assert_false, assert_true, assert_raises) -from sklearn.feature_selection.mutual_info_ import mutual_info, _compute_mi +from sklearn.feature_selection.mutual_info_ import ( + mutual_info_regression, mutual_info_classif, _compute_mi) class TestMIComputation(object): @@ -101,7 +102,7 @@ def test_discrete(self): # Here X[:, 0] is the most informative feature, and X[:, 1] is weakly # informative. - mi = mutual_info(X, y, discrete_features=True, discrete_target=True) + mi = mutual_info_classif(X, y, discrete_features=True) assert_array_equal(np.argsort(-mi), np.array([0, 2, 1])) def test_continuous(self): @@ -124,7 +125,7 @@ def test_continuous(self): X = Z[:, 1:] y = Z[:, 0] - mi = mutual_info(X, y, random_state=0) + mi = mutual_info_regression(X, y, random_state=0) assert_array_equal(np.argsort(-mi), np.array([1, 2, 0])) def test_mixed(self): @@ -136,8 +137,7 @@ def test_mixed(self): y = ((0.5 * X[:, 0] + X[:, 2]) > 0.5).astype(int) X[:, 2] = X[:, 2] > 0.5 - mi = mutual_info(X, y, discrete_features=[2], discrete_target=True, - random_state=0) + mi = mutual_info_classif(X, y, discrete_features=[2], random_state=0) assert_array_equal(np.argsort(-mi), [2, 0, 1]) def test_discrete_features_option(self): @@ -145,25 +145,24 @@ def test_discrete_features_option(self): [1, 1, 0], [2, 0, 1], [2, 0, 1], - [2, 0, 1]]) - y = np.array([0, 1, 2, 2, 1]) + [2, 0, 1]], dtype=float) + y = np.array([0, 1, 2, 2, 1], dtype=float) X_csr = csr_matrix(X) - assert_raises(ValueError, mutual_info, X_csr, y, - discrete_features=False) + for mutual_info in (mutual_info_regression, mutual_info_classif): + assert_raises(ValueError, mutual_info_regression, X_csr, y, + discrete_features=False) - mi_1 = mutual_info(X, y, discrete_features='auto', - discrete_target=True, random_state=0) - mi_2 = mutual_info(X, y, discrete_features=False, - discrete_target=True, random_state=0) + mi_1 = mutual_info(X, y, discrete_features='auto', random_state=0) + mi_2 = mutual_info(X, y, discrete_features=False, random_state=0) - mi_3 = mutual_info(X_csr, y, discrete_features='auto', - discrete_target=True) - mi_4 = mutual_info(X_csr, y, discrete_features=True, - discrete_target=True) + mi_3 = mutual_info(X_csr, y, discrete_features='auto', + random_state=0) + mi_4 = mutual_info(X_csr, y, discrete_features=True, + random_state=0) - assert_array_equal(mi_1, mi_2) - assert_array_equal(mi_3, mi_4) + assert_array_equal(mi_1, mi_2) + assert_array_equal(mi_3, mi_4) assert_false(np.allclose(mi_1, mi_3)) diff --git a/sklearn/feature_selection/univariate_selection.py b/sklearn/feature_selection/univariate_selection.py index b74c5cd5928e3..9ec3fe17cf961 100644 --- 
a/sklearn/feature_selection/univariate_selection.py +++ b/sklearn/feature_selection/univariate_selection.py @@ -295,7 +295,7 @@ class _BaseFilter(BaseEstimator, SelectorMixin): ---------- score_func : callable Function taking two arrays X and y, and returning a pair of arrays - (scores, pvalues) or a single array scores. + (scores, pvalues) or a single array with scores. """ def __init__(self, score_func): @@ -354,7 +354,7 @@ class SelectPercentile(_BaseFilter): ---------- score_func : callable Function taking two arrays X and y, and returning a pair of arrays - (scores, pvalues) or a single array scores. + (scores, pvalues) or a single array with scores. percentile : int, optional, default=10 Percent of features to keep. @@ -375,9 +375,10 @@ class SelectPercentile(_BaseFilter): See also -------- f_classif: ANOVA F-value between label/feature for classification tasks. + mutual_info_classif: Mutual information for a discrete target. chi2: Chi-squared stats of non-negative features for classification tasks. f_regression: F-value between label/feature for regression tasks. - mutual_info: Mutual information between features and the target. + mutual_info_regression: Mutual information for a continuous target. SelectKBest: Select features based on the k highest scores. SelectFpr: Select features based on a false positive rate test. SelectFdr: Select features based on an estimated false discovery rate. @@ -424,7 +425,7 @@ class SelectKBest(_BaseFilter): ---------- score_func : callable Function taking two arrays X and y, and returning a pair of arrays - (scores, pvalues). + (scores, pvalues) or a single scores array. k : int or "all", optional, default=10 Number of top features to select. @@ -446,9 +447,10 @@ class SelectKBest(_BaseFilter): See also -------- f_classif: ANOVA F-value between label/feature for classification tasks. + mutual_info_classif: Mutual information for a discrete target. chi2: Chi-squared stats of non-negative features for classification tasks. f_regression: F-value between label/feature for regression tasks. - mutual_info: Mutual information between features and the target. + mutual_info_regression: Mutual information for a continuous target. SelectPercentile: Select features based on percentile of the highest scores. SelectFpr: Select features based on a false positive rate test. SelectFdr: Select features based on an estimated false discovery rate. @@ -512,8 +514,9 @@ class SelectFpr(_BaseFilter): -------- f_classif: ANOVA F-value between label/feature for classification tasks. chi2: Chi-squared stats of non-negative features for classification tasks. + mutual_info_classif: Mutual information for a discrete target. f_regression: F-value between label/feature for regression tasks. - mutual_info: Mutual information between features and the target. + mutual_info_regression: Mutual information for a continuous target. SelectPercentile: Select features based on percentile of the highest scores. SelectKBest: Select features based on the k highest scores. SelectFdr: Select features based on an estimated false discovery rate. @@ -564,9 +567,10 @@ class SelectFdr(_BaseFilter): See also -------- f_classif: ANOVA F-value between label/feature for classification tasks. + mutual_info_classif: Mutual information for a discrete target. chi2: Chi-squared stats of non-negative features for classification tasks. f_regression: F-value between label/feature for regression tasks. - mutual_info: Mutual information between features and the target. + mutual_info_regression: Mutual information for a continuous target. 
SelectPercentile: Select features based on percentile of the highest scores. SelectKBest: Select features based on the k highest scores. SelectFpr: Select features based on a false positive rate test. @@ -617,7 +621,6 @@ class SelectFwe(_BaseFilter): f_classif: ANOVA F-value between label/feature for classification tasks. chi2: Chi-squared stats of non-negative features for classification tasks. f_regression: F-value between label/feature for regression tasks. - mutual_info: Mutual information between features and the target. SelectPercentile: Select features based on percentile of the highest scores. SelectKBest: Select features based on the k highest scores. SelectFpr: Select features based on a false positive rate test. @@ -670,9 +673,10 @@ class GenericUnivariateSelect(_BaseFilter): See also -------- f_classif: ANOVA F-value between label/feature for classification tasks. + mutual_info_classif: Mutual information for a discrete target. chi2: Chi-squared stats of non-negative features for classification tasks. f_regression: F-value between label/feature for regression tasks. - mutual_info: Mutual information between features and the target. + mutual_info_regression: Mutual information for a continuous target. SelectPercentile: Select features based on percentile of the highest scores. SelectKBest: Select features based on the k highest scores. SelectFpr: Select features based on a false positive rate test. From 8394c1bfccba77f91947c38e41cfa8987f73df25 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Fri, 15 Jan 2016 21:28:56 +0500 Subject: [PATCH 07/25] MAINT: Add blank lines between parameters in mutual_info_.py --- sklearn/feature_selection/mutual_info_.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/sklearn/feature_selection/mutual_info_.py b/sklearn/feature_selection/mutual_info_.py index d0903dc5e5120..e4a08b55a9a41 100644 --- a/sklearn/feature_selection/mutual_info_.py +++ b/sklearn/feature_selection/mutual_info_.py @@ -20,7 +20,8 @@ def _compute_mi_cc(x, y, n_neighbors): Parameters ---------- x, y : ndarray - Samples from random variables, 1-d arrays of identical shape. + Samples of random variables, 1-d arrays of identical shape. + n_neighbors : int Number of nearest neighbors to search for each point, see [1]_. @@ -77,8 +78,10 @@ def _compute_mi_cd(c, d, n_neighbors): ---------- c : ndarray Samples from a continuous random variable. + d : ndarray Samples from a discrete random variable. + n_neighbors : int Number of nearest neighbors to search for each point, see [1]_. @@ -154,6 +157,7 @@ def _get_column(X, i): ---------- X : ndarray or csc_matrix, shape (n_samples, n_features) Matrix from which to get a column. + i : int Column index. @@ -179,6 +183,7 @@ def _iterate_columns(X, columns=None): ---------- X : ndarray or csc_matrix, shape (n_samples, n_features) Matrix over which to iterate. + columns : iterable or None, default None Indices of columns to iterate over. If None, iterate over all columns. @@ -202,23 +207,29 @@ def _estimate_mi(X, y, discrete_features='auto', discrete_target=False, ---------- X : array_like or sparse matrix, shape (n_samples, n_features) Feature matrix. + y : array_like, shape (n_samples,) Target vector. + discrete_features : {'auto', bool, array_like}, default 'auto' If bool, then determines whether to consider all features discrete or continuous. If array, then it should be either a boolean mask with shape (n_features,) or array with indices of discrete features. 
If 'auto', it is assigned to False for dense `X` and to True for sparse `X`. + discrete_target : bool, default False Whether to consider `y` as a discrete variable. + n_neighbors : int, default 3 Number of neighbors to use for MI estimation for continuous variables, see [1]_ and [2]_. Higher values reduce variance of the estimation, but could introduce a bias. + copy : bool, default True Whether to make a copy of the given data. If set to False, the initial data will be overwritten. + random_state : int seed, RandomState instance or None, default None The seed of the pseudo random number generator for adding small noise to continuous variables in order to remove repeated values. @@ -301,21 +312,26 @@ def mutual_info_regression(X, y, discrete_features='auto', n_neighbors=3, ---------- X : array_like or sparse matrix, shape (n_samples, n_features) Feature matrix. + y : array_like, shape (n_samples,) Target vector. + discrete_features : {'auto', bool, array_like}, default 'auto' If bool, then determines whether to consider all features discrete or continuous. If array, then it should be either a boolean mask with shape (n_features,) or array with indices of discrete features. If 'auto', it is assigned to False for dense `X` and to True for sparse `X`. + n_neighbors : int, default 3 Number of neighbors to use for MI estimation for continuous variables, see [2]_ and [3]_. Higher values reduce variance of the estimation, but could introduce a bias. + copy : bool, default True Whether to make a copy of the given data. If set to False, the initial data will be overwritten. + random_state : int seed, RandomState instance or None, default None The seed of the pseudo random number generator for adding small noise to continuous variables in order to remove repeated values. @@ -368,21 +384,26 @@ def mutual_info_classif(X, y, discrete_features='auto', n_neighbors=3, ---------- X : array_like or sparse matrix, shape (n_samples, n_features) Feature matrix. + y : array_like, shape (n_samples,) Target vector. + discrete_features : {'auto', bool, array_like}, default 'auto' If bool, then determines whether to consider all features discrete or continuous. If array, then it should be either a boolean mask with shape (n_features,) or array with indices of discrete features. If 'auto', it is assigned to False for dense `X` and to True for sparse `X`. + n_neighbors : int, default 3 Number of neighbors to use for MI estimation for continuous variables, see [2]_ and [3]_. Higher values reduce variance of the estimation, but could introduce a bias. + copy : bool, default True Whether to make a copy of the given data. If set to False, the initial data will be overwritten. + random_state : int seed, RandomState instance or None, default None The seed of the pseudo random number generator for adding small noise to continuous variables in order to remove repeated values. 
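A minimal usage sketch of the public signature documented above, assuming the whole series is applied; the toy data and variable names below are illustrative only and do not come from the patches::

    import numpy as np
    from sklearn.feature_selection import mutual_info_regression

    rng = np.random.RandomState(0)
    X = rng.rand(500, 3)
    # Feature 0 drives the target linearly, feature 1 nonlinearly,
    # feature 2 is pure noise.
    y = X[:, 0] + np.sin(3 * np.pi * X[:, 1]) + 0.05 * rng.randn(500)

    # n_neighbors trades variance for bias in the kNN-based estimate;
    # random_state seeds the small noise added to break ties between
    # repeated continuous values.
    mi = mutual_info_regression(X, y, discrete_features=False,
                                n_neighbors=3, random_state=0)
    print(mi)  # the score of the noise feature should be close to 0
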
From ad2f5f5275927ab63a3d5cca1fe5e6087ba4bbda Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Fri, 15 Jan 2016 21:44:17 +0500 Subject: [PATCH 08/25] MAINT: Add check_classification_targets to mutual_info_classif --- sklearn/feature_selection/mutual_info_.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/feature_selection/mutual_info_.py b/sklearn/feature_selection/mutual_info_.py index e4a08b55a9a41..394de7997f884 100644 --- a/sklearn/feature_selection/mutual_info_.py +++ b/sklearn/feature_selection/mutual_info_.py @@ -12,6 +12,7 @@ from ..preprocessing import scale from ..utils import check_random_state from ..utils.validation import check_X_y +from ..utils.multiclass import check_classification_targets def _compute_mi_cc(x, y, n_neighbors): @@ -433,5 +434,6 @@ def mutual_info_classif(X, y, discrete_features='auto', n_neighbors=3, .. [3] B. C. Ross "Mutual Information between Discrete and Continuous Data Sets". PLoS ONE 9(2), 2014. """ + check_classification_targets(y) return _estimate_mi(X, y, discrete_features, True, n_neighbors, copy, random_state) From ffc4fe91dae1aed39fcb8548f53ba434d8858207 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Fri, 15 Jan 2016 22:20:32 +0500 Subject: [PATCH 09/25] TST: Change tolerance checks in test_mutual_info.py --- sklearn/feature_selection/tests/test_mutual_info.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/feature_selection/tests/test_mutual_info.py b/sklearn/feature_selection/tests/test_mutual_info.py index 6c7bd7cbde5d5..ad9942e27ce09 100644 --- a/sklearn/feature_selection/tests/test_mutual_info.py +++ b/sklearn/feature_selection/tests/test_mutual_info.py @@ -48,11 +48,11 @@ def test_cc(self): x, y = Z[:, 0], Z[:, 1] - # Theory and computed values won't be very close, assert that relative - # error is less than 10%. + # Theory and computed values won't be very close, assert that the + # first figures after decimal point match. for n_neighbors in [3, 5, 7]: I_computed = _compute_mi(x, y, False, False, n_neighbors) - assert_true(np.abs(I_computed - I_theory) < 0.1 * I_theory) + assert_almost_equal(I_computed, I_theory, 1) def test_cd(self): # To test define a joint distribution as follows: @@ -84,10 +84,10 @@ def test_cd(self): I_theory = -0.5 * ((1 - p) * np.log(0.5 * (1 - p)) + p * np.log(0.5 * p) + np.log(0.5)) - np.log(2) - # Again assert that relative error is less than 10%. + # Assert the same tolerance. 
for n_neighbors in [3, 5, 7]: I_computed = _compute_mi(x, y, True, False, n_neighbors) - assert_true(np.abs(I_computed - I_theory) < 0.1 * I_theory) + assert_almost_equal(I_computed, I_theory, 1) class TestMutualInfo(object): From 824dda37671fe32fd76e5790c4857c1c28f21da3 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Fri, 15 Jan 2016 22:24:41 +0500 Subject: [PATCH 10/25] MAINT: Small changes to plot_f_test_vs_mi.py --- examples/feature_selection/plot_f_test_vs_mi.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/feature_selection/plot_f_test_vs_mi.py b/examples/feature_selection/plot_f_test_vs_mi.py index ea178be9e99c9..7917f19962dc6 100644 --- a/examples/feature_selection/plot_f_test_vs_mi.py +++ b/examples/feature_selection/plot_f_test_vs_mi.py @@ -26,6 +26,7 @@ import matplotlib.pyplot as plt from sklearn.feature_selection import f_regression, mutual_info_regression +np.random.seed(0) X = np.random.rand(1000, 3) y = X[:, 0] + np.sin(6 * np.pi * X[:, 1]) + 0.1 * np.random.randn(1000) @@ -45,3 +46,4 @@ plt.title("F-test={:.2f}, MI={:.2f}".format(f_test[i], mi[i]), fontsize=16) plt.show() + From 051d3a285d75b5463a0fbffa385c4d1a291c6a5b Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Sat, 16 Jan 2016 03:03:00 +0500 Subject: [PATCH 11/25] MAINT: Slightly improve logic of discrete-continuous MI estimation --- sklearn/feature_selection/mutual_info_.py | 26 ++++++++++++------- .../tests/test_mutual_info.py | 20 +++++++++++++- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/sklearn/feature_selection/mutual_info_.py b/sklearn/feature_selection/mutual_info_.py index 394de7997f884..b165d4342567b 100644 --- a/sklearn/feature_selection/mutual_info_.py +++ b/sklearn/feature_selection/mutual_info_.py @@ -104,33 +104,41 @@ def _compute_mi_cd(c, d, n_neighbors): .. [1] B. C. Ross "Mutual Information between Discrete and Continuous Data Sets". PLoS ONE 9(2), 2014. """ + n_samples = c.shape[0] c = c.reshape((-1, 1)) - n_samples = c.size - - nn = NearestNeighbors(n_neighbors=n_neighbors) radius = np.empty(n_samples) label_counts = np.empty(n_samples) + k_all = np.empty(n_samples) + nn = NearestNeighbors() for label in np.unique(d): mask = d == label count = np.sum(mask) if count > 1: - nn.set_params(n_neighbors=min(n_neighbors, count - 1)) + k = min(n_neighbors, count - 1) + nn.set_params(n_neighbors=k) nn.fit(c[mask]) r = nn.kneighbors()[0] radius[mask] = np.nextafter(r[:, -1], 0) - else: - radius[mask] = 0 + k_all[mask] = k label_counts[mask] = count + # Ignore points with unique labels. 
+ mask = label_counts > 1 + n_samples = np.sum(mask) + label_counts = label_counts[mask] + k_all = k_all[mask] + c = c[mask] + radius = radius[mask] + nn.set_params(algorithm='kd_tree') nn.fit(c) ind = nn.radius_neighbors(radius=radius, return_distance=False) - neighbor_counts = np.array([i.size for i in ind]) + m_all = np.array([i.size for i in ind]) - mi = (digamma(n_samples) + digamma(n_neighbors) - + mi = (digamma(n_samples) + np.mean(digamma(k_all)) - np.mean(digamma(label_counts)) - - np.mean(digamma(neighbor_counts + 1))) + np.mean(digamma(m_all + 1))) return max(0, mi) diff --git a/sklearn/feature_selection/tests/test_mutual_info.py b/sklearn/feature_selection/tests/test_mutual_info.py index ad9942e27ce09..a68d4ce0fd4be 100644 --- a/sklearn/feature_selection/tests/test_mutual_info.py +++ b/sklearn/feature_selection/tests/test_mutual_info.py @@ -5,7 +5,7 @@ from scipy.sparse import csr_matrix from sklearn.utils.testing import (assert_array_equal, assert_almost_equal, - assert_false, assert_true, assert_raises) + assert_false, assert_raises, assert_equal) from sklearn.feature_selection.mutual_info_ import ( mutual_info_regression, mutual_info_classif, _compute_mi) @@ -89,6 +89,24 @@ def test_cd(self): I_computed = _compute_mi(x, y, True, False, n_neighbors) assert_almost_equal(I_computed, I_theory, 1) + def test_cd_unique_label(self): + # Test that adding unique label doesn't change MI. + n_samples = 100 + x = np.random.uniform(size=n_samples) > 0.5 + + y = np.empty(n_samples) + mask = x == 0 + y[mask] = np.random.uniform(-1, 1, size=np.sum(mask)) + y[~mask] = np.random.uniform(0, 2, size=np.sum(~mask)) + + mi_1 = _compute_mi(x, y, True, False) + + x = np.hstack((x, 2)) + y = np.hstack((y, 10)) + mi_2 = _compute_mi(x, y, True, False) + + assert_equal(mi_1, mi_2) + class TestMutualInfo(object): # We are going test that feature ordering by MI matches our expectations. From 786999258d54c92e1920ddf7f388215f699aed5b Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Sat, 16 Jan 2016 03:23:19 +0500 Subject: [PATCH 12/25] MAINT: Slightly improve copy logic in _estimate_mi --- sklearn/feature_selection/mutual_info_.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/sklearn/feature_selection/mutual_info_.py b/sklearn/feature_selection/mutual_info_.py index b165d4342567b..044520c2cd45f 100644 --- a/sklearn/feature_selection/mutual_info_.py +++ b/sklearn/feature_selection/mutual_info_.py @@ -277,16 +277,15 @@ def _estimate_mi(X, y, discrete_features='auto', discrete_target=False, if np.any(continuous_mask) and issparse(X): raise ValueError("Sparse matrix `X` can't have continuous features.") - if copy: - X = X.copy() - - if not discrete_target and np.any(continuous_mask): - X[:, continuous_mask] = scale(X[:, continuous_mask], - with_mean=False, copy=False) - - # Add small noise to continuous features as advised in Kraskov et. al. rng = check_random_state(random_state) if np.any(continuous_mask): + if not discrete_target: + X[:, continuous_mask] = scale(X[:, continuous_mask], + with_mean=False, copy=copy) + elif copy: + X = X.copy() + + # Add small noise to continuous features as advised in Kraskov et. al. 
X = X.astype(float) means = np.maximum(1, np.mean(np.abs(X[:, continuous_mask]), axis=0)) X[:, continuous_mask] += 1e-10 * means * rng.randn( From ec172895f8fb6f36ca63eb4b642847c8167e7ee8 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Sat, 16 Jan 2016 04:03:31 +0500 Subject: [PATCH 13/25] DOC: Add short descriptions of methods for mutual info estimation --- sklearn/feature_selection/mutual_info_.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/feature_selection/mutual_info_.py b/sklearn/feature_selection/mutual_info_.py index 044520c2cd45f..0fd7ab00904a3 100644 --- a/sklearn/feature_selection/mutual_info_.py +++ b/sklearn/feature_selection/mutual_info_.py @@ -310,8 +310,8 @@ def mutual_info_regression(X, y, discrete_features='auto', n_neighbors=3, to zero if and only if two random variables are independent, and higher values mean higher dependency. - This function relies on the algorithms of MI estimation described in [2]_ - and [3]_. + The function relies on nonparametric methods based on entropy estimation + from k-nearest neighbors distances. Refer to [2]_ and [3]_. It can be used for univariate features selection, read more in the :ref:`User Guide `. @@ -382,8 +382,8 @@ def mutual_info_classif(X, y, discrete_features='auto', n_neighbors=3, to zero if and only if two random variables are independent, and higher values mean higher dependency. - This function relies on the algorithms of MI estimation described in [2]_ - and [3]_. + The function relies on nonparametric methods based on entropy estimation + from k-nearest neighbors distances. Refer to [2]_ and [3]_. It can be used for univariate features selection, read more in the :ref:`User Guide `. From b0491be8bd6ed0508d9e48094829270859d0d53d Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Sat, 16 Jan 2016 16:36:31 +0500 Subject: [PATCH 14/25] DOC: Add a short explanation of F-test vs MI in narrative doc --- doc/modules/feature_selection.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/modules/feature_selection.rst b/doc/modules/feature_selection.rst index 2b70c25af82a9..8b7bfee654e77 100644 --- a/doc/modules/feature_selection.rst +++ b/doc/modules/feature_selection.rst @@ -92,6 +92,11 @@ and p-values (or only scores for :class:`SelectKBest` and * For classification: :func:`chi2`, :func:`f_classif`, :func:`mutual_info_classif` +The methods based on F-test estimate the degree of linear dependency between +two random variables. On the other hand, mutual information methods can capture +any kind of statistical dependency, but being nonparametric, they require more +samples for accurate estimation. + .. topic:: Feature selection with sparse data If you use sparse data (i.e. 
data represented as sparse matrices), From d3a497aec4c9fd5eddd4453a9cc6764c60adaf2a Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Sun, 17 Jan 2016 00:39:25 +0500 Subject: [PATCH 15/25] BUG: Fix copy logic for mutual info functions --- sklearn/feature_selection/mutual_info_.py | 7 ++++--- sklearn/feature_selection/tests/test_feature_select.py | 5 +++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sklearn/feature_selection/mutual_info_.py b/sklearn/feature_selection/mutual_info_.py index 0fd7ab00904a3..fe1ba3db40965 100644 --- a/sklearn/feature_selection/mutual_info_.py +++ b/sklearn/feature_selection/mutual_info_.py @@ -279,11 +279,12 @@ def _estimate_mi(X, y, discrete_features='auto', discrete_target=False, rng = check_random_state(random_state) if np.any(continuous_mask): + if copy: + X = X.copy() + if not discrete_target: X[:, continuous_mask] = scale(X[:, continuous_mask], - with_mean=False, copy=copy) - elif copy: - X = X.copy() + with_mean=False, copy=False) # Add small noise to continuous features as advised in Kraskov et. al. X = X.astype(float) diff --git a/sklearn/feature_selection/tests/test_feature_select.py b/sklearn/feature_selection/tests/test_feature_select.py index e23b9a3c20843..b1c965829f2ef 100644 --- a/sklearn/feature_selection/tests/test_feature_select.py +++ b/sklearn/feature_selection/tests/test_feature_select.py @@ -7,6 +7,7 @@ import numpy as np from scipy import stats, sparse +from numpy.testing import run_module_suite from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_almost_equal from sklearn.utils.testing import assert_raises @@ -614,3 +615,7 @@ def test_mutual_info_regression(): gtruth = np.zeros(20) gtruth[:5] = 1 assert_array_equal(support, gtruth) + + +if __name__ == '__main__': + run_module_suite() From 375b070263c00fead8a1d99b56e0c55ec27d161d Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Mon, 18 Jan 2016 01:16:52 +0500 Subject: [PATCH 16/25] TST: Speed up 2 tests related to mutual info --- .../tests/test_feature_select.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/sklearn/feature_selection/tests/test_feature_select.py b/sklearn/feature_selection/tests/test_feature_select.py index b1c965829f2ef..d2a73334299d5 100644 --- a/sklearn/feature_selection/tests/test_feature_select.py +++ b/sklearn/feature_selection/tests/test_feature_select.py @@ -559,61 +559,61 @@ def test_no_feature_selected(): assert_equal(X_selected.shape, (40, 0)) -def test_mutual_info_classification(): - X, y = make_classification(n_samples=200, n_features=20, - n_informative=3, n_redundant=2, - n_repeated=0, n_classes=8, +def test_mutual_info_classif(): + X, y = make_classification(n_samples=100, n_features=5, + n_informative=1, n_redundant=1, + n_repeated=0, n_classes=2, n_clusters_per_class=1, flip_y=0.0, class_sep=10, shuffle=False, random_state=0) # Test in KBest mode. - univariate_filter = SelectKBest(mutual_info_classif, k=5) + univariate_filter = SelectKBest(mutual_info_classif, k=2) X_r = univariate_filter.fit(X, y).transform(X) X_r2 = GenericUnivariateSelect( - mutual_info_classif, mode='k_best', param=5).fit(X, y).transform(X) + mutual_info_classif, mode='k_best', param=2).fit(X, y).transform(X) assert_array_equal(X_r, X_r2) support = univariate_filter.get_support() - gtruth = np.zeros(20) - gtruth[:5] = 1 + gtruth = np.zeros(5) + gtruth[:2] = 1 assert_array_equal(support, gtruth) # Test in Percentile mode. 
- univariate_filter = SelectPercentile(mutual_info_classif, percentile=25) + univariate_filter = SelectPercentile(mutual_info_classif, percentile=40) X_r = univariate_filter.fit(X, y).transform(X) X_r2 = GenericUnivariateSelect( - mutual_info_classif, mode='percentile', param=25).fit(X, y).transform(X) + mutual_info_classif, mode='percentile', param=40).fit(X, y).transform(X) assert_array_equal(X_r, X_r2) support = univariate_filter.get_support() - gtruth = np.zeros(20) - gtruth[:5] = 1 + gtruth = np.zeros(5) + gtruth[:2] = 1 assert_array_equal(support, gtruth) def test_mutual_info_regression(): - X, y = make_regression(n_samples=200, n_features=20, n_informative=5, + X, y = make_regression(n_samples=100, n_features=10, n_informative=2, shuffle=False, random_state=0, noise=10) # Test in KBest mode. - univariate_filter = SelectKBest(mutual_info_regression, k=5) + univariate_filter = SelectKBest(mutual_info_regression, k=2) X_r = univariate_filter.fit(X, y).transform(X) assert_best_scores_kept(univariate_filter) X_r2 = GenericUnivariateSelect( - mutual_info_regression, mode='k_best', param=5).fit(X, y).transform(X) + mutual_info_regression, mode='k_best', param=2).fit(X, y).transform(X) assert_array_equal(X_r, X_r2) support = univariate_filter.get_support() - gtruth = np.zeros(20) - gtruth[:5] = 1 + gtruth = np.zeros(10) + gtruth[:2] = 1 assert_array_equal(support, gtruth) # Test in Percentile mode. - univariate_filter = SelectPercentile(mutual_info_regression, percentile=25) + univariate_filter = SelectPercentile(mutual_info_regression, percentile=20) X_r = univariate_filter.fit(X, y).transform(X) X_r2 = GenericUnivariateSelect(mutual_info_regression, mode='percentile', - param=25).fit(X, y).transform(X) + param=20).fit(X, y).transform(X) assert_array_equal(X_r, X_r2) support = univariate_filter.get_support() - gtruth = np.zeros(20) - gtruth[:5] = 1 + gtruth = np.zeros(10) + gtruth[:2] = 1 assert_array_equal(support, gtruth) From d60636a4a80018b40022c389c0af8eb9d955101e Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Mon, 18 Jan 2016 02:42:10 +0500 Subject: [PATCH 17/25] DOC: Small fixes in mutual_info_.py documentation --- sklearn/feature_selection/mutual_info_.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sklearn/feature_selection/mutual_info_.py b/sklearn/feature_selection/mutual_info_.py index fe1ba3db40965..d8ca8d0782dc0 100644 --- a/sklearn/feature_selection/mutual_info_.py +++ b/sklearn/feature_selection/mutual_info_.py @@ -20,8 +20,9 @@ def _compute_mi_cc(x, y, n_neighbors): Parameters ---------- - x, y : ndarray - Samples of random variables, 1-d arrays of identical shape. + x, y : ndarray, shape (n_samples,) + Samples of two continuous random variables, must have an identical + shape. n_neighbors : int Number of nearest neighbors to search for each point, see [1]_. @@ -77,11 +78,11 @@ def _compute_mi_cd(c, d, n_neighbors): Parameters ---------- - c : ndarray - Samples from a continuous random variable. + c : ndarray, shape (n_samples,) + Samples of a continuous random variable. - d : ndarray - Samples from a discrete random variable. + d : ndarray, shape (n_samples,) + Samples of a discrete random variable. n_neighbors : int Number of nearest neighbors to search for each point, see [1]_. 
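The sped-up tests above plug the new score functions straight into the existing selectors. A condensed sketch of that pattern, assuming the series is applied (it mirrors test_mutual_info_classif rather than quoting it)::

    from sklearn.datasets import make_classification
    from sklearn.feature_selection import SelectKBest, mutual_info_classif

    # Same setup as the test: with shuffle=False the informative and
    # redundant features occupy the first two columns.
    X, y = make_classification(n_samples=100, n_features=5, n_informative=1,
                               n_redundant=1, n_repeated=0, n_classes=2,
                               n_clusters_per_class=1, flip_y=0.0,
                               class_sep=10, shuffle=False, random_state=0)

    # SelectKBest accepts a score function that returns scores only.
    selector = SelectKBest(mutual_info_classif, k=2)
    X_reduced = selector.fit_transform(X, y)
    print(selector.get_support())  # the two leading features are kept
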
From 094a077af5578277a5d703e6aa238b96efa81707 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Mon, 18 Jan 2016 02:50:54 +0500 Subject: [PATCH 18/25] MAINT: Small refactoring in mutual_info_.py --- sklearn/feature_selection/mutual_info_.py | 37 ++++++----------------- 1 file changed, 9 insertions(+), 28 deletions(-) diff --git a/sklearn/feature_selection/mutual_info_.py b/sklearn/feature_selection/mutual_info_.py index d8ca8d0782dc0..5829b1b76f30d 100644 --- a/sklearn/feature_selection/mutual_info_.py +++ b/sklearn/feature_selection/mutual_info_.py @@ -160,32 +160,6 @@ def _compute_mi(x, y, x_discrete, y_discrete, n_neighbors=3): return _compute_mi_cc(x, y, n_neighbors) -def _get_column(X, i): - """Get column of a matrix. - - Parameters - ---------- - X : ndarray or csc_matrix, shape (n_samples, n_features) - Matrix from which to get a column. - - i : int - Column index. - - Returns - ------- - xi : ndarray, shape (n_samples,) - i-th column of `X` in dense format. - """ - if issparse(X): - x = np.zeros(X.shape[0]) - start_ptr, end_ptr = X.indptr[i], X.indptr[i + 1] - x[X.indices[start_ptr:end_ptr]] = X.data[start_ptr:end_ptr] - else: - x = X[:, i] - - return x - - def _iterate_columns(X, columns=None): """Iterate over columns of a matrix. @@ -205,8 +179,15 @@ def _iterate_columns(X, columns=None): if columns is None: columns = range(X.shape[1]) - for i in columns: - yield _get_column(X, i) + if issparse(X): + for i in columns: + x = np.zeros(X.shape[0]) + start_ptr, end_ptr = X.indptr[i], X.indptr[i + 1] + x[X.indices[start_ptr:end_ptr]] = X.data[start_ptr:end_ptr] + yield x + else: + for i in columns: + yield X[:, i] def _estimate_mi(X, y, discrete_features='auto', discrete_target=False, From 5b3f51515ba0f64079693be28bee28e3db469e6a Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Mon, 18 Jan 2016 02:57:57 +0500 Subject: [PATCH 19/25] MAINT: Get rid of classes in test_mutual_info.py --- .../tests/test_mutual_info.py | 320 +++++++++--------- 1 file changed, 162 insertions(+), 158 deletions(-) diff --git a/sklearn/feature_selection/tests/test_mutual_info.py b/sklearn/feature_selection/tests/test_mutual_info.py index a68d4ce0fd4be..f9b86777dcbe3 100644 --- a/sklearn/feature_selection/tests/test_mutual_info.py +++ b/sklearn/feature_selection/tests/test_mutual_info.py @@ -10,179 +10,183 @@ mutual_info_regression, mutual_info_classif, _compute_mi) -class TestMIComputation(object): - def test_dd(self): - # In discrete case computations are straightforward and can be done - # by hand on given vectors. - x = np.array([0, 1, 1, 0, 0]) - y = np.array([1, 0, 0, 0, 1]) - - H_x = H_y = -(3/5) * np.log(3/5) - (2/5) * np.log(2/5) - H_xy = -1/5 * np.log(1/5) - 2/5 * np.log(2/5) - 2/5 * np.log(2/5) - I_xy = H_x + H_y - H_xy - - assert_almost_equal(_compute_mi(x, y, True, True), I_xy) - - def test_cc(self): - # For two continuous variables a good approach is to test on bivariate - # normal distribution, where mutual information is known. - - # Mean of the distribution, irrelevant for mutual information. - mean = np.zeros(2) - - # Setup covariance matrix with correlation coeff. equal 0.5. - sigma_1 = 1 - sigma_2 = 10 - corr = 0.5 - cov = np.array([ - [sigma_1**2, corr * sigma_1 * sigma_2], - [corr * sigma_1 * sigma_2, sigma_2**2] - ]) - - # True theoretical mutual information. 
- I_theory = (np.log(sigma_1) + np.log(sigma_2) - - 0.5 * np.log(np.linalg.det(cov))) - - np.random.seed(0) - Z = np.random.multivariate_normal(mean, cov, size=1000) - - x, y = Z[:, 0], Z[:, 1] - - # Theory and computed values won't be very close, assert that the - # first figures after decimal point match. - for n_neighbors in [3, 5, 7]: - I_computed = _compute_mi(x, y, False, False, n_neighbors) - assert_almost_equal(I_computed, I_theory, 1) +def test_compute_mi_dd(): + # In discrete case computations are straightforward and can be done + # by hand on given vectors. + x = np.array([0, 1, 1, 0, 0]) + y = np.array([1, 0, 0, 0, 1]) + + H_x = H_y = -(3/5) * np.log(3/5) - (2/5) * np.log(2/5) + H_xy = -1/5 * np.log(1/5) - 2/5 * np.log(2/5) - 2/5 * np.log(2/5) + I_xy = H_x + H_y - H_xy + + assert_almost_equal(_compute_mi(x, y, True, True), I_xy) + + +def test_compute_mi_cc(): + # For two continuous variables a good approach is to test on bivariate + # normal distribution, where mutual information is known. + + # Mean of the distribution, irrelevant for mutual information. + mean = np.zeros(2) - def test_cd(self): - # To test define a joint distribution as follows: - # p(x, y) = p(x) p(y | x) - # X ~ Bernoulli(p) - # (Y | x = 0) ~ Uniform(-1, 1) - # (Y | x = 1) ~ Uniform(0, 2) + # Setup covariance matrix with correlation coeff. equal 0.5. + sigma_1 = 1 + sigma_2 = 10 + corr = 0.5 + cov = np.array([ + [sigma_1**2, corr * sigma_1 * sigma_2], + [corr * sigma_1 * sigma_2, sigma_2**2] + ]) - # Use the following formula for mutual information: - # I(X; Y) = H(Y) - H(Y | X) - # Two entropies can be computed by hand: - # H(Y) = -(1-p)/2 * ln((1-p)/2) - p/2*log(p/2) - 1/2*log(1/2) - # H(Y | X) = ln(2) + # True theoretical mutual information. + I_theory = (np.log(sigma_1) + np.log(sigma_2) - + 0.5 * np.log(np.linalg.det(cov))) - # Now we need to implement sampling from out distribution, which is - # done easily using conditional distribution logic. + np.random.seed(0) + Z = np.random.multivariate_normal(mean, cov, size=1000) - n_samples = 1000 - np.random.seed(0) + x, y = Z[:, 0], Z[:, 1] - for p in [0.3, 0.5, 0.7]: - x = np.random.uniform(size=n_samples) > p + # Theory and computed values won't be very close, assert that the + # first figures after decimal point match. + for n_neighbors in [3, 5, 7]: + I_computed = _compute_mi(x, y, False, False, n_neighbors) + assert_almost_equal(I_computed, I_theory, 1) - y = np.empty(n_samples) - mask = x == 0 - y[mask] = np.random.uniform(-1, 1, size=np.sum(mask)) - y[~mask] = np.random.uniform(0, 2, size=np.sum(~mask)) - I_theory = -0.5 * ((1 - p) * np.log(0.5 * (1 - p)) + - p * np.log(0.5 * p) + np.log(0.5)) - np.log(2) +def test_compute_mi_cd(): + # To test define a joint distribution as follows: + # p(x, y) = p(x) p(y | x) + # X ~ Bernoulli(p) + # (Y | x = 0) ~ Uniform(-1, 1) + # (Y | x = 1) ~ Uniform(0, 2) - # Assert the same tolerance. - for n_neighbors in [3, 5, 7]: - I_computed = _compute_mi(x, y, True, False, n_neighbors) - assert_almost_equal(I_computed, I_theory, 1) + # Use the following formula for mutual information: + # I(X; Y) = H(Y) - H(Y | X) + # Two entropies can be computed by hand: + # H(Y) = -(1-p)/2 * ln((1-p)/2) - p/2*log(p/2) - 1/2*log(1/2) + # H(Y | X) = ln(2) - def test_cd_unique_label(self): - # Test that adding unique label doesn't change MI. - n_samples = 100 - x = np.random.uniform(size=n_samples) > 0.5 + # Now we need to implement sampling from out distribution, which is + # done easily using conditional distribution logic. 
+ + n_samples = 1000 + np.random.seed(0) + + for p in [0.3, 0.5, 0.7]: + x = np.random.uniform(size=n_samples) > p y = np.empty(n_samples) mask = x == 0 y[mask] = np.random.uniform(-1, 1, size=np.sum(mask)) y[~mask] = np.random.uniform(0, 2, size=np.sum(~mask)) - mi_1 = _compute_mi(x, y, True, False) - - x = np.hstack((x, 2)) - y = np.hstack((y, 10)) - mi_2 = _compute_mi(x, y, True, False) - - assert_equal(mi_1, mi_2) - - -class TestMutualInfo(object): - # We are going test that feature ordering by MI matches our expectations. - def test_discrete(self): - X = np.array([[0, 0, 0], - [1, 1, 0], - [2, 0, 1], - [2, 0, 1], - [2, 0, 1]]) - y = np.array([0, 1, 2, 2, 1]) - - # Here X[:, 0] is the most informative feature, and X[:, 1] is weakly - # informative. - mi = mutual_info_classif(X, y, discrete_features=True) - assert_array_equal(np.argsort(-mi), np.array([0, 2, 1])) - - def test_continuous(self): - # We generate sample from multivariate normal distribution, using - # transformation from initially uncorrelated variables. The zero - # variables after transformation is selected as the target vector, - # it has the strongest correlation with the variable 2, and - # the weakest correlation with the variable 1. - T = np.array([ - [1, 0.5, 2, 1], - [0, 1, 0.1, 0.0], - [0, 0.1, 1, 0.1], - [0, 0.1, 0.1, 1] - ]) - cov = T.dot(T.T) - mean = np.zeros(4) - - np.random.seed(0) - Z = np.random.multivariate_normal(mean, cov, size=1000) - X = Z[:, 1:] - y = Z[:, 0] - - mi = mutual_info_regression(X, y, random_state=0) - assert_array_equal(np.argsort(-mi), np.array([1, 2, 0])) - - def test_mixed(self): - # Here the target is discrete and there are two continuous and one - # discrete feature. The idea of this test is clear from the code. - np.random.seed(0) - X = np.random.rand(1000, 3) - X[:, 1] += X[:, 0] - y = ((0.5 * X[:, 0] + X[:, 2]) > 0.5).astype(int) - X[:, 2] = X[:, 2] > 0.5 - - mi = mutual_info_classif(X, y, discrete_features=[2], random_state=0) - assert_array_equal(np.argsort(-mi), [2, 0, 1]) - - def test_discrete_features_option(self): - X = np.array([[0, 0, 0], - [1, 1, 0], - [2, 0, 1], - [2, 0, 1], - [2, 0, 1]], dtype=float) - y = np.array([0, 1, 2, 2, 1], dtype=float) - X_csr = csr_matrix(X) - - for mutual_info in (mutual_info_regression, mutual_info_classif): - assert_raises(ValueError, mutual_info_regression, X_csr, y, - discrete_features=False) - - mi_1 = mutual_info(X, y, discrete_features='auto', random_state=0) - mi_2 = mutual_info(X, y, discrete_features=False, random_state=0) - - mi_3 = mutual_info(X_csr, y, discrete_features='auto', - random_state=0) - mi_4 = mutual_info(X_csr, y, discrete_features=True, - random_state=0) - - assert_array_equal(mi_1, mi_2) - assert_array_equal(mi_3, mi_4) - - assert_false(np.allclose(mi_1, mi_3)) + I_theory = -0.5 * ((1 - p) * np.log(0.5 * (1 - p)) + + p * np.log(0.5 * p) + np.log(0.5)) - np.log(2) + + # Assert the same tolerance. + for n_neighbors in [3, 5, 7]: + I_computed = _compute_mi(x, y, True, False, n_neighbors) + assert_almost_equal(I_computed, I_theory, 1) + + +def test_compute_mi_cd_unique_label(): + # Test that adding unique label doesn't change MI. 
+ n_samples = 100 + x = np.random.uniform(size=n_samples) > 0.5 + + y = np.empty(n_samples) + mask = x == 0 + y[mask] = np.random.uniform(-1, 1, size=np.sum(mask)) + y[~mask] = np.random.uniform(0, 2, size=np.sum(~mask)) + + mi_1 = _compute_mi(x, y, True, False) + + x = np.hstack((x, 2)) + y = np.hstack((y, 10)) + mi_2 = _compute_mi(x, y, True, False) + + assert_equal(mi_1, mi_2) + + +# We are going test that feature ordering by MI matches our expectations. +def test_mutual_info_classif_discrete(): + X = np.array([[0, 0, 0], + [1, 1, 0], + [2, 0, 1], + [2, 0, 1], + [2, 0, 1]]) + y = np.array([0, 1, 2, 2, 1]) + + # Here X[:, 0] is the most informative feature, and X[:, 1] is weakly + # informative. + mi = mutual_info_classif(X, y, discrete_features=True) + assert_array_equal(np.argsort(-mi), np.array([0, 2, 1])) + + +def test_mutual_info_regression(): + # We generate sample from multivariate normal distribution, using + # transformation from initially uncorrelated variables. The zero + # variables after transformation is selected as the target vector, + # it has the strongest correlation with the variable 2, and + # the weakest correlation with the variable 1. + T = np.array([ + [1, 0.5, 2, 1], + [0, 1, 0.1, 0.0], + [0, 0.1, 1, 0.1], + [0, 0.1, 0.1, 1] + ]) + cov = T.dot(T.T) + mean = np.zeros(4) + + np.random.seed(0) + Z = np.random.multivariate_normal(mean, cov, size=1000) + X = Z[:, 1:] + y = Z[:, 0] + + mi = mutual_info_regression(X, y, random_state=0) + assert_array_equal(np.argsort(-mi), np.array([1, 2, 0])) + + +def test_mutual_info_classif_mixed(): + # Here the target is discrete and there are two continuous and one + # discrete feature. The idea of this test is clear from the code. + np.random.seed(0) + X = np.random.rand(1000, 3) + X[:, 1] += X[:, 0] + y = ((0.5 * X[:, 0] + X[:, 2]) > 0.5).astype(int) + X[:, 2] = X[:, 2] > 0.5 + + mi = mutual_info_classif(X, y, discrete_features=[2], random_state=0) + assert_array_equal(np.argsort(-mi), [2, 0, 1]) + + +def test_mutual_info_options(): + X = np.array([[0, 0, 0], + [1, 1, 0], + [2, 0, 1], + [2, 0, 1], + [2, 0, 1]], dtype=float) + y = np.array([0, 1, 2, 2, 1], dtype=float) + X_csr = csr_matrix(X) + + for mutual_info in (mutual_info_regression, mutual_info_classif): + assert_raises(ValueError, mutual_info_regression, X_csr, y, + discrete_features=False) + + mi_1 = mutual_info(X, y, discrete_features='auto', random_state=0) + mi_2 = mutual_info(X, y, discrete_features=False, random_state=0) + + mi_3 = mutual_info(X_csr, y, discrete_features='auto', + random_state=0) + mi_4 = mutual_info(X_csr, y, discrete_features=True, + random_state=0) + + assert_array_equal(mi_1, mi_2) + assert_array_equal(mi_3, mi_4) + + assert_false(np.allclose(mi_1, mi_3)) if __name__ == '__main__': From b48a10870a1a3fbb858b8e4e9e1580747c84eba5 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Wed, 20 Jan 2016 17:06:45 +0500 Subject: [PATCH 20/25] DOC: Add one more reference for mutual info methods --- sklearn/feature_selection/mutual_info_.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sklearn/feature_selection/mutual_info_.py b/sklearn/feature_selection/mutual_info_.py index 5829b1b76f30d..f99a0c157da04 100644 --- a/sklearn/feature_selection/mutual_info_.py +++ b/sklearn/feature_selection/mutual_info_.py @@ -294,7 +294,8 @@ def mutual_info_regression(X, y, discrete_features='auto', n_neighbors=3, values mean higher dependency. 
The function relies on nonparametric methods based on entropy estimation - from k-nearest neighbors distances. Refer to [2]_ and [3]_. + from k-nearest neighbors distances as described in [2]_ and [3]_. Both + methods are based on the idea originally proposed in [4]_. It can be used for univariate features selection, read more in the :ref:`User Guide `. @@ -351,6 +352,8 @@ def mutual_info_regression(X, y, discrete_features='auto', n_neighbors=3, information". Phys. Rev. E 69, 2004. .. [3] B. C. Ross "Mutual Information between Discrete and Continuous Data Sets". PLoS ONE 9(2), 2014. + .. [4] L. F. Kozachenko, N. N. Leonenko, “Sample Estimate of the Entropy + of a Random Vector”, Probl. Peredachi Inf., 23:2 (1987), 9–16 """ return _estimate_mi(X, y, discrete_features, False, n_neighbors, copy, random_state) @@ -366,7 +369,8 @@ def mutual_info_classif(X, y, discrete_features='auto', n_neighbors=3, values mean higher dependency. The function relies on nonparametric methods based on entropy estimation - from k-nearest neighbors distances. Refer to [2]_ and [3]_. + from k-nearest neighbors distances as described in [2]_ and [3]_. Both + methods are based on the idea originally proposed in [4]_. It can be used for univariate features selection, read more in the :ref:`User Guide `. @@ -423,6 +427,8 @@ def mutual_info_classif(X, y, discrete_features='auto', n_neighbors=3, information". Phys. Rev. E 69, 2004. .. [3] B. C. Ross "Mutual Information between Discrete and Continuous Data Sets". PLoS ONE 9(2), 2014. + .. [4] L. F. Kozachenko, N. N. Leonenko, “Sample Estimate of the Entropy + of a Random Vector”, Probl. Peredachi Inf., 23:2 (1987), 9–16 """ check_classification_targets(y) return _estimate_mi(X, y, discrete_features, True, n_neighbors, From e1bc05628aba16ddf0620a8314a79c23fb787e99 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Wed, 20 Jan 2016 17:24:10 +0500 Subject: [PATCH 21/25] MAINT: Add a clarification comment in mutual_info_.py --- sklearn/feature_selection/mutual_info_.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/feature_selection/mutual_info_.py b/sklearn/feature_selection/mutual_info_.py index f99a0c157da04..650ccd9697738 100644 --- a/sklearn/feature_selection/mutual_info_.py +++ b/sklearn/feature_selection/mutual_info_.py @@ -51,12 +51,15 @@ def _compute_mi_cc(x, y, n_neighbors): y = y.reshape((-1, 1)) xy = np.hstack((x, y)) + # Here we rely on NearestNeighbors to select the fastest algorithm. nn = NearestNeighbors(metric='chebyshev', n_neighbors=n_neighbors) nn.fit(xy) radius = nn.kneighbors()[0] radius = np.nextafter(radius[:, -1], 0) + # Algorithm is selected explicitly to allow passing an array as radius + # later (not all algorithms support this). nn.set_params(algorithm='kd_tree') nn.fit(x) From a36edf2a3d4fe2d09d885307afb5fa271c7cc132 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Wed, 20 Jan 2016 17:46:30 +0500 Subject: [PATCH 22/25] DOC: Modify SelectKBest and SelectPercentile docstrings slightly --- sklearn/feature_selection/univariate_selection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/feature_selection/univariate_selection.py b/sklearn/feature_selection/univariate_selection.py index 9ec3fe17cf961..d365ec221575f 100644 --- a/sklearn/feature_selection/univariate_selection.py +++ b/sklearn/feature_selection/univariate_selection.py @@ -365,7 +365,7 @@ class SelectPercentile(_BaseFilter): Scores of features. 
pvalues_ : array-like, shape=(n_features,) - p-values of feature scores, None if `score_func` returned scores only. + p-values of feature scores, None if `score_func` returned only scores. Notes ----- @@ -425,7 +425,7 @@ class SelectKBest(_BaseFilter): ---------- score_func : callable Function taking two arrays X and y, and returning a pair of arrays - (scores, pvalues) or a single scores array. + (scores, pvalues) or a single array with scores. k : int or "all", optional, default=10 Number of top features to select. @@ -437,7 +437,7 @@ class SelectKBest(_BaseFilter): Scores of features. pvalues_ : array-like, shape=(n_features,) - p-values of feature scores, None if `score_func` returned scores only. + p-values of feature scores, None if `score_func` returned only scores. Notes ----- From e716c6489c1a2633d81d741b58dc0a932b11d145 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Wed, 20 Jan 2016 17:48:29 +0500 Subject: [PATCH 23/25] MAINT: Mention mutual info methods in whats_new.rst --- doc/whats_new.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 720d7c3cb6ca7..b7cac4ad04653 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -15,6 +15,13 @@ Changelog New features ............ + - Added two functions for mutual information estimation: + :func:`feature_selection.mutual_info_classif` and + :func:`feature_selection.mutual_info_regression`. These functions can be + used in :class:`feature_selection.SelectKBest` and + :class:`feature_selection.SelectPercentile`, which now accept a callable + returning only `scores`. By `Nikolay Mayorov`_. + - The Gaussian Process module has been reimplemented and now offers classification and regression estimators through :class:`gaussian_process.GaussianProcessClassifier` and :class:`gaussian_process.GaussianProcessRegressor`. Among other things, the new From daa73c7429082f0fea7665359d33a5f8d65b5989 Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Wed, 20 Jan 2016 18:22:10 +0500 Subject: [PATCH 24/25] BUG: Remove non-ASCII symbols from mutual_info_.py --- sklearn/feature_selection/mutual_info_.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/feature_selection/mutual_info_.py b/sklearn/feature_selection/mutual_info_.py index 650ccd9697738..0b205c2011c7a 100644 --- a/sklearn/feature_selection/mutual_info_.py +++ b/sklearn/feature_selection/mutual_info_.py @@ -355,8 +355,8 @@ def mutual_info_regression(X, y, discrete_features='auto', n_neighbors=3, information". Phys. Rev. E 69, 2004. .. [3] B. C. Ross "Mutual Information between Discrete and Continuous Data Sets". PLoS ONE 9(2), 2014. - .. [4] L. F. Kozachenko, N. N. Leonenko, “Sample Estimate of the Entropy - of a Random Vector”, Probl. Peredachi Inf., 23:2 (1987), 9–16 + .. [4] L. F. Kozachenko, N. N. Leonenko, "Sample Estimate of the Entropy + of a Random Vector", Probl. Peredachi Inf., 23:2 (1987), 9-16 """ return _estimate_mi(X, y, discrete_features, False, n_neighbors, copy, random_state) @@ -430,8 +430,8 @@ def mutual_info_classif(X, y, discrete_features='auto', n_neighbors=3, information". Phys. Rev. E 69, 2004. .. [3] B. C. Ross "Mutual Information between Discrete and Continuous Data Sets". PLoS ONE 9(2), 2014. - .. [4] L. F. Kozachenko, N. N. Leonenko, “Sample Estimate of the Entropy - of a Random Vector”, Probl. Peredachi Inf., 23:2 (1987), 9–16 + .. [4] L. F. Kozachenko, N. N. Leonenko, "Sample Estimate of the Entropy + of a Random Vector", Probl. 
Peredachi Inf., 23:2 (1987), 9-16 """ check_classification_targets(y) return _estimate_mi(X, y, discrete_features, True, n_neighbors, copy, random_state) From 4cc82a3f43bd279e6864030122e4902464f74c3b Mon Sep 17 00:00:00 2001 From: Nikolay Mayorov Date: Fri, 22 Jan 2016 00:23:35 +0500 Subject: [PATCH 25/25] MAINT: Modify whats_new item related to mutual information --- doc/whats_new.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index b7cac4ad04653..e6fb4f8fc4ffa 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -20,7 +20,7 @@ New features :func:`feature_selection.mutual_info_regression`. These functions can be used in :class:`feature_selection.SelectKBest` and :class:`feature_selection.SelectPercentile`, which now accept a callable - returning only `scores`. By `Nikolay Mayorov`_. + returning only `scores`. By `Andrea Bravi`_ and `Nikolay Mayorov`_. @@ -4044,3 +4044,6 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson. .. _Imaculate: https://github.com/Imaculate .. _Bernardo Stein: https://github.com/DanielSidhion + +.. _Andrea Bravi: https://github.com/AndreaBravi
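
To see why the narrative docs (patch 14) say that F-test scores capture only linear dependency while mutual information also picks up nonlinear dependency, the comparison from examples/feature_selection/plot_f_test_vs_mi.py can be run without the plotting code. A condensed sketch, assuming the series is applied::

    import numpy as np
    from sklearn.feature_selection import f_regression, mutual_info_regression

    np.random.seed(0)
    X = np.random.rand(1000, 3)
    y = X[:, 0] + np.sin(6 * np.pi * X[:, 1]) + 0.1 * np.random.randn(1000)

    f_test, _ = f_regression(X, y)
    mi = mutual_info_regression(X, y, random_state=0)

    # The F-test ranks the linear feature 0 first; mutual information also
    # detects the nonlinear dependence on feature 1 and ranks it highest.
    print(f_test / f_test.max())
    print(mi / mi.max())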