Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

ENH add n_jobs to mutual_info_regression and mutual_info_classif #28085

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/whats_new/v1.5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ Changelog
by passing a function in place of a strategy name.
:pr:`28053` by :user:`Mark Elliot <mark-thm>`.

:mod:`sklearn.feature_selection`
................................

- |Enhancement| :func:`feature_selection.mutual_info_regression` and
  :func:`feature_selection.mutual_info_classif` now support the `n_jobs` parameter.
  :pr:`28085` by :user:`Neto Menoci <netomenoci>` and
  :user:`Florin Andrei <FlorinAndrei>`.

:mod:`sklearn.metrics`
......................

Expand Down
80 changes: 73 additions & 7 deletions sklearn/feature_selection/_mutual_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ..utils import check_random_state
from ..utils._param_validation import Interval, StrOptions, validate_params
from ..utils.multiclass import check_classification_targets
from ..utils.parallel import Parallel, delayed
from ..utils.validation import check_array, check_X_y


Expand Down Expand Up @@ -201,11 +202,13 @@ def _iterate_columns(X, columns=None):
def _estimate_mi(
X,
y,
*,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used keyword arguments to make the code more readable in the _estimate_mi function call and also use keyword-only arguments in the _estimate_mi definition.

I guess that's fine since _estimate_mi is private. @glemaitre do you agree?

Copy link
Member

@glemaitre glemaitre Jan 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I do agree.

discrete_features="auto",
discrete_target=False,
n_neighbors=3,
copy=True,
random_state=None,
n_jobs=None,
):
"""Estimate mutual information between the features and the target.

Expand Down Expand Up @@ -242,6 +245,16 @@ def _estimate_mi(
Pass an int for reproducible results across multiple function calls.
See :term:`Glossary <random_state>`.

n_jobs : int, default=None
The number of jobs to use for computing the mutual information.
The parallelization is done on the columns of `X`.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
for more details.

.. versionadded:: 1.5


Returns
-------
mi : ndarray, shape (n_features,)
Expand Down Expand Up @@ -301,10 +314,10 @@ def _estimate_mi(
* rng.standard_normal(size=n_samples)
)

mi = [
_compute_mi(x, y, discrete_feature, discrete_target, n_neighbors)
mi = Parallel(n_jobs=n_jobs)(
delayed(_compute_mi)(x, y, discrete_feature, discrete_target, n_neighbors)
for x, discrete_feature in zip(_iterate_columns(X), discrete_mask)
]
)

return np.array(mi)

Expand All @@ -317,11 +330,19 @@ def _estimate_mi(
"n_neighbors": [Interval(Integral, 1, None, closed="left")],
"copy": ["boolean"],
"random_state": ["random_state"],
"n_jobs": [Integral, None],
},
prefer_skip_nested_validation=True,
)
def mutual_info_regression(
X, y, *, discrete_features="auto", n_neighbors=3, copy=True, random_state=None
X,
y,
*,
discrete_features="auto",
n_neighbors=3,
copy=True,
random_state=None,
n_jobs=None,
):
"""Estimate mutual information for a continuous target variable.

Expand Down Expand Up @@ -367,6 +388,16 @@ def mutual_info_regression(
Pass an int for reproducible results across multiple function calls.
See :term:`Glossary <random_state>`.

n_jobs : int, default=None
The number of jobs to use for computing the mutual information.
The parallelization is done on the columns of `X`.

``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
for more details.

.. versionadded:: 1.5

Returns
-------
mi : ndarray, shape (n_features,)
Expand Down Expand Up @@ -407,7 +438,16 @@ def mutual_info_regression(
>>> mutual_info_regression(X, y)
array([0.1..., 2.6... , 0.0...])
"""
return _estimate_mi(X, y, discrete_features, False, n_neighbors, copy, random_state)
return _estimate_mi(
X,
y,
discrete_features=discrete_features,
discrete_target=False,
n_neighbors=n_neighbors,
copy=copy,
random_state=random_state,
n_jobs=n_jobs,
)


@validate_params(
Expand All @@ -418,11 +458,19 @@ def mutual_info_regression(
"n_neighbors": [Interval(Integral, 1, None, closed="left")],
"copy": ["boolean"],
"random_state": ["random_state"],
"n_jobs": [Integral, None],
},
prefer_skip_nested_validation=True,
)
def mutual_info_classif(
X, y, *, discrete_features="auto", n_neighbors=3, copy=True, random_state=None
X,
y,
*,
discrete_features="auto",
n_neighbors=3,
copy=True,
random_state=None,
n_jobs=None,
):
"""Estimate mutual information for a discrete target variable.

Expand Down Expand Up @@ -468,6 +516,15 @@ def mutual_info_classif(
Pass an int for reproducible results across multiple function calls.
See :term:`Glossary <random_state>`.

n_jobs : int, default=None
The number of jobs to use for computing the mutual information.
The parallelization is done on the columns of `X`.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
for more details.

.. versionadded:: 1.5

Returns
-------
mi : ndarray, shape (n_features,)
Expand Down Expand Up @@ -511,4 +568,13 @@ def mutual_info_classif(
0. , 0. , 0. , 0. , 0. ])
"""
check_classification_targets(y)
return _estimate_mi(X, y, discrete_features, True, n_neighbors, copy, random_state)
return _estimate_mi(
X,
y,
discrete_features=discrete_features,
discrete_target=True,
n_neighbors=n_neighbors,
copy=copy,
random_state=random_state,
n_jobs=n_jobs,
)
16 changes: 16 additions & 0 deletions sklearn/feature_selection/tests/test_mutual_info.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest

from sklearn.datasets import make_classification, make_regression
from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
from sklearn.feature_selection._mutual_info import _compute_mi
from sklearn.utils import check_random_state
Expand Down Expand Up @@ -252,3 +253,18 @@ def test_mutual_info_regression_X_int_dtype(global_random_seed):
expected = mutual_info_regression(X_float, y, random_state=global_random_seed)
result = mutual_info_regression(X, y, random_state=global_random_seed)
assert_allclose(result, expected)


@pytest.mark.parametrize(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@netomenoci I push a piece of code that show how to make the parallelization if you are interested in.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome, thanks :)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi bary Wery good nice good luck bary

"mutual_info_func, data_generator",
[
(mutual_info_regression, make_regression),
(mutual_info_classif, make_classification),
],
)
def test_mutual_info_n_jobs(global_random_seed, mutual_info_func, data_generator):
    """Check that results are consistent with different `n_jobs`."""
    # Same seed for both calls so any difference can only come from the
    # parallelization itself, not from random noise added to the features.
    X, y = data_generator(random_state=global_random_seed)
    single_job = mutual_info_func(X, y, random_state=global_random_seed, n_jobs=1)
    multi_job = mutual_info_func(X, y, random_state=global_random_seed, n_jobs=2)
    # Per-column MI estimates must agree regardless of the number of workers.
    assert_allclose(single_job, multi_job)