Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 06e566e

Browse files
netomenoci, glemaitre, lesteve
authored
ENH add n_jobs to mutual_info_regression and mutual_info_classif (#28085)
Co-authored-by: Guillaume Lemaitre <[email protected]> Co-authored-by: Loïc Estève <[email protected]>
1 parent 7836435 commit 06e566e

File tree

3 files changed

+97
-7
lines changed

3 files changed

+97
-7
lines changed

doc/whats_new/v1.5.rst

+8
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ Changelog
5454
by passing a function in place of a strategy name.
5555
:pr:`28053` by :user:`Mark Elliot <mark-thm>`.
5656

57+
:mod:`sklearn.feature_selection`
58+
................................
59+
60+
- |Enhancement| :func:`feature_selection.mutual_info_regression` and
61+
:func:`feature_selection.mutual_info_classif` now support the `n_jobs` parameter.
62+
:pr:`28085` by :user:`Neto Menoci <netomenoci>` and
63+
:user:`Florin Andrei <FlorinAndrei>`.
64+
5765
:mod:`sklearn.metrics`
5866
......................
5967

sklearn/feature_selection/_mutual_info.py

+73-7
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ..utils import check_random_state
1414
from ..utils._param_validation import Interval, StrOptions, validate_params
1515
from ..utils.multiclass import check_classification_targets
16+
from ..utils.parallel import Parallel, delayed
1617
from ..utils.validation import check_array, check_X_y
1718

1819

@@ -201,11 +202,13 @@ def _iterate_columns(X, columns=None):
201202
def _estimate_mi(
202203
X,
203204
y,
205+
*,
204206
discrete_features="auto",
205207
discrete_target=False,
206208
n_neighbors=3,
207209
copy=True,
208210
random_state=None,
211+
n_jobs=None,
209212
):
210213
"""Estimate mutual information between the features and the target.
211214
@@ -242,6 +245,16 @@ def _estimate_mi(
242245
Pass an int for reproducible results across multiple function calls.
243246
See :term:`Glossary <random_state>`.
244247
248+
n_jobs : int, default=None
249+
The number of jobs to use for computing the mutual information.
250+
The parallelization is done on the columns of `X`.
251+
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
252+
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
253+
for more details.
254+
255+
.. versionadded:: 1.5
256+
257+
245258
Returns
246259
-------
247260
mi : ndarray, shape (n_features,)
@@ -301,10 +314,10 @@ def _estimate_mi(
301314
* rng.standard_normal(size=n_samples)
302315
)
303316

304-
mi = [
305-
_compute_mi(x, y, discrete_feature, discrete_target, n_neighbors)
317+
mi = Parallel(n_jobs=n_jobs)(
318+
delayed(_compute_mi)(x, y, discrete_feature, discrete_target, n_neighbors)
306319
for x, discrete_feature in zip(_iterate_columns(X), discrete_mask)
307-
]
320+
)
308321

309322
return np.array(mi)
310323

@@ -317,11 +330,19 @@ def _estimate_mi(
317330
"n_neighbors": [Interval(Integral, 1, None, closed="left")],
318331
"copy": ["boolean"],
319332
"random_state": ["random_state"],
333+
"n_jobs": [Integral, None],
320334
},
321335
prefer_skip_nested_validation=True,
322336
)
323337
def mutual_info_regression(
324-
X, y, *, discrete_features="auto", n_neighbors=3, copy=True, random_state=None
338+
X,
339+
y,
340+
*,
341+
discrete_features="auto",
342+
n_neighbors=3,
343+
copy=True,
344+
random_state=None,
345+
n_jobs=None,
325346
):
326347
"""Estimate mutual information for a continuous target variable.
327348
@@ -367,6 +388,16 @@ def mutual_info_regression(
367388
Pass an int for reproducible results across multiple function calls.
368389
See :term:`Glossary <random_state>`.
369390
391+
n_jobs : int, default=None
392+
The number of jobs to use for computing the mutual information.
393+
The parallelization is done on the columns of `X`.
394+
395+
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
396+
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
397+
for more details.
398+
399+
.. versionadded:: 1.5
400+
370401
Returns
371402
-------
372403
mi : ndarray, shape (n_features,)
@@ -407,7 +438,16 @@ def mutual_info_regression(
407438
>>> mutual_info_regression(X, y)
408439
array([0.1..., 2.6... , 0.0...])
409440
"""
410-
return _estimate_mi(X, y, discrete_features, False, n_neighbors, copy, random_state)
441+
return _estimate_mi(
442+
X,
443+
y,
444+
discrete_features=discrete_features,
445+
discrete_target=False,
446+
n_neighbors=n_neighbors,
447+
copy=copy,
448+
random_state=random_state,
449+
n_jobs=n_jobs,
450+
)
411451

412452

413453
@validate_params(
@@ -418,11 +458,19 @@ def mutual_info_regression(
418458
"n_neighbors": [Interval(Integral, 1, None, closed="left")],
419459
"copy": ["boolean"],
420460
"random_state": ["random_state"],
461+
"n_jobs": [Integral, None],
421462
},
422463
prefer_skip_nested_validation=True,
423464
)
424465
def mutual_info_classif(
425-
X, y, *, discrete_features="auto", n_neighbors=3, copy=True, random_state=None
466+
X,
467+
y,
468+
*,
469+
discrete_features="auto",
470+
n_neighbors=3,
471+
copy=True,
472+
random_state=None,
473+
n_jobs=None,
426474
):
427475
"""Estimate mutual information for a discrete target variable.
428476
@@ -468,6 +516,15 @@ def mutual_info_classif(
468516
Pass an int for reproducible results across multiple function calls.
469517
See :term:`Glossary <random_state>`.
470518
519+
n_jobs : int, default=None
520+
The number of jobs to use for computing the mutual information.
521+
The parallelization is done on the columns of `X`.
522+
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
523+
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
524+
for more details.
525+
526+
.. versionadded:: 1.5
527+
471528
Returns
472529
-------
473530
mi : ndarray, shape (n_features,)
@@ -511,4 +568,13 @@ def mutual_info_classif(
511568
0. , 0. , 0. , 0. , 0. ])
512569
"""
513570
check_classification_targets(y)
514-
return _estimate_mi(X, y, discrete_features, True, n_neighbors, copy, random_state)
571+
return _estimate_mi(
572+
X,
573+
y,
574+
discrete_features=discrete_features,
575+
discrete_target=True,
576+
n_neighbors=n_neighbors,
577+
copy=copy,
578+
random_state=random_state,
579+
n_jobs=n_jobs,
580+
)

sklearn/feature_selection/tests/test_mutual_info.py

+16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
from sklearn.datasets import make_classification, make_regression
45
from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
56
from sklearn.feature_selection._mutual_info import _compute_mi
67
from sklearn.utils import check_random_state
@@ -252,3 +253,18 @@ def test_mutual_info_regression_X_int_dtype(global_random_seed):
252253
expected = mutual_info_regression(X_float, y, random_state=global_random_seed)
253254
result = mutual_info_regression(X, y, random_state=global_random_seed)
254255
assert_allclose(result, expected)
256+
257+
258+
@pytest.mark.parametrize(
259+
"mutual_info_func, data_generator",
260+
[
261+
(mutual_info_regression, make_regression),
262+
(mutual_info_classif, make_classification),
263+
],
264+
)
265+
def test_mutual_info_n_jobs(global_random_seed, mutual_info_func, data_generator):
266+
"""Check that results are consistent with different `n_jobs`."""
267+
X, y = data_generator(random_state=global_random_seed)
268+
single_job = mutual_info_func(X, y, random_state=global_random_seed, n_jobs=1)
269+
multi_job = mutual_info_func(X, y, random_state=global_random_seed, n_jobs=2)
270+
assert_allclose(single_job, multi_job)

0 commit comments

Comments
 (0)