Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Fix _estimate_mi discrete_features str and value check #13497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 1, 2019
14 changes: 8 additions & 6 deletions sklearn/feature_selection/mutual_info_.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..preprocessing import scale
from ..utils import check_random_state
from ..utils.fixes import _astype_copy_false
from ..utils.validation import check_X_y
from ..utils.validation import check_array, check_X_y
from ..utils.multiclass import check_classification_targets


Expand Down Expand Up @@ -247,14 +247,16 @@ def _estimate_mi(X, y, discrete_features='auto', discrete_target=False,
X, y = check_X_y(X, y, accept_sparse='csc', y_numeric=not discrete_target)
n_samples, n_features = X.shape

if discrete_features == 'auto':
discrete_features = issparse(X)

if isinstance(discrete_features, bool):
if isinstance(discrete_features, (str, bool)):
if isinstance(discrete_features, str):
if discrete_features == 'auto':
discrete_features = issparse(X)
else:
raise ValueError("Invalid string value for discrete_features.")
discrete_mask = np.empty(n_features, dtype=bool)
discrete_mask.fill(discrete_features)
else:
discrete_features = np.asarray(discrete_features)
discrete_features = check_array(discrete_features, ensure_2d=False)
if discrete_features.dtype != 'bool':
discrete_mask = np.zeros(n_features, dtype=bool)
discrete_mask[discrete_features] = True
Expand Down
18 changes: 13 additions & 5 deletions sklearn/feature_selection/tests/test_mutual_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,18 +183,26 @@ def test_mutual_info_options():
X_csr = csr_matrix(X)

for mutual_info in (mutual_info_regression, mutual_info_classif):
assert_raises(ValueError, mutual_info_regression, X_csr, y,
assert_raises(ValueError, mutual_info, X_csr, y,
discrete_features=False)
assert_raises(ValueError, mutual_info, X, y,
discrete_features='manual')
assert_raises(ValueError, mutual_info, X_csr, y,
discrete_features=[True, False, True])
assert_raises(IndexError, mutual_info, X, y,
discrete_features=[True, False, True, False])
assert_raises(IndexError, mutual_info, X, y, discrete_features=[1, 4])

mi_1 = mutual_info(X, y, discrete_features='auto', random_state=0)
mi_2 = mutual_info(X, y, discrete_features=False, random_state=0)

mi_3 = mutual_info(X_csr, y, discrete_features='auto',
random_state=0)
mi_4 = mutual_info(X_csr, y, discrete_features=True,
mi_3 = mutual_info(X_csr, y, discrete_features='auto', random_state=0)
mi_4 = mutual_info(X_csr, y, discrete_features=True, random_state=0)
mi_5 = mutual_info(X, y, discrete_features=[True, False, True],
random_state=0)
mi_6 = mutual_info(X, y, discrete_features=[0, 2], random_state=0)

assert_array_equal(mi_1, mi_2)
assert_array_equal(mi_3, mi_4)
assert_array_equal(mi_5, mi_6)

assert not np.allclose(mi_1, mi_3)