13
13
from ..utils import check_random_state
14
14
from ..utils ._param_validation import Interval , StrOptions , validate_params
15
15
from ..utils .multiclass import check_classification_targets
16
+ from ..utils .parallel import Parallel , delayed
16
17
from ..utils .validation import check_array , check_X_y
17
18
18
19
@@ -201,11 +202,13 @@ def _iterate_columns(X, columns=None):
201
202
def _estimate_mi (
202
203
X ,
203
204
y ,
205
+ * ,
204
206
discrete_features = "auto" ,
205
207
discrete_target = False ,
206
208
n_neighbors = 3 ,
207
209
copy = True ,
208
210
random_state = None ,
211
+ n_jobs = None ,
209
212
):
210
213
"""Estimate mutual information between the features and the target.
211
214
@@ -242,6 +245,16 @@ def _estimate_mi(
242
245
Pass an int for reproducible results across multiple function calls.
243
246
See :term:`Glossary <random_state>`.
244
247
248
+ n_jobs : int, default=None
249
+ The number of jobs to use for computing the mutual information.
250
+ The parallelization is done on the columns of `X`.
251
+ ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
252
+ ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
253
+ for more details.
254
+
255
+ .. versionadded:: 1.5
256
+
257
+
245
258
Returns
246
259
-------
247
260
mi : ndarray, shape (n_features,)
@@ -301,10 +314,10 @@ def _estimate_mi(
301
314
* rng .standard_normal (size = n_samples )
302
315
)
303
316
304
- mi = [
305
- _compute_mi (x , y , discrete_feature , discrete_target , n_neighbors )
317
+ mi = Parallel ( n_jobs = n_jobs )(
318
+ delayed ( _compute_mi ) (x , y , discrete_feature , discrete_target , n_neighbors )
306
319
for x , discrete_feature in zip (_iterate_columns (X ), discrete_mask )
307
- ]
320
+ )
308
321
309
322
return np .array (mi )
310
323
@@ -317,11 +330,19 @@ def _estimate_mi(
317
330
"n_neighbors" : [Interval (Integral , 1 , None , closed = "left" )],
318
331
"copy" : ["boolean" ],
319
332
"random_state" : ["random_state" ],
333
+ "n_jobs" : [Integral , None ],
320
334
},
321
335
prefer_skip_nested_validation = True ,
322
336
)
323
337
def mutual_info_regression (
324
- X , y , * , discrete_features = "auto" , n_neighbors = 3 , copy = True , random_state = None
338
+ X ,
339
+ y ,
340
+ * ,
341
+ discrete_features = "auto" ,
342
+ n_neighbors = 3 ,
343
+ copy = True ,
344
+ random_state = None ,
345
+ n_jobs = None ,
325
346
):
326
347
"""Estimate mutual information for a continuous target variable.
327
348
@@ -367,6 +388,16 @@ def mutual_info_regression(
367
388
Pass an int for reproducible results across multiple function calls.
368
389
See :term:`Glossary <random_state>`.
369
390
391
+ n_jobs : int, default=None
392
+ The number of jobs to use for computing the mutual information.
393
+ The parallelization is done on the columns of `X`.
394
+
395
+ ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
396
+ ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
397
+ for more details.
398
+
399
+ .. versionadded:: 1.5
400
+
370
401
Returns
371
402
-------
372
403
mi : ndarray, shape (n_features,)
@@ -407,7 +438,16 @@ def mutual_info_regression(
407
438
>>> mutual_info_regression(X, y)
408
439
array([0.1..., 2.6... , 0.0...])
409
440
"""
410
- return _estimate_mi (X , y , discrete_features , False , n_neighbors , copy , random_state )
441
+ return _estimate_mi (
442
+ X ,
443
+ y ,
444
+ discrete_features = discrete_features ,
445
+ discrete_target = False ,
446
+ n_neighbors = n_neighbors ,
447
+ copy = copy ,
448
+ random_state = random_state ,
449
+ n_jobs = n_jobs ,
450
+ )
411
451
412
452
413
453
@validate_params (
@@ -418,11 +458,19 @@ def mutual_info_regression(
418
458
"n_neighbors" : [Interval (Integral , 1 , None , closed = "left" )],
419
459
"copy" : ["boolean" ],
420
460
"random_state" : ["random_state" ],
461
+ "n_jobs" : [Integral , None ],
421
462
},
422
463
prefer_skip_nested_validation = True ,
423
464
)
424
465
def mutual_info_classif (
425
- X , y , * , discrete_features = "auto" , n_neighbors = 3 , copy = True , random_state = None
466
+ X ,
467
+ y ,
468
+ * ,
469
+ discrete_features = "auto" ,
470
+ n_neighbors = 3 ,
471
+ copy = True ,
472
+ random_state = None ,
473
+ n_jobs = None ,
426
474
):
427
475
"""Estimate mutual information for a discrete target variable.
428
476
@@ -468,6 +516,15 @@ def mutual_info_classif(
468
516
Pass an int for reproducible results across multiple function calls.
469
517
See :term:`Glossary <random_state>`.
470
518
519
+ n_jobs : int, default=None
520
+ The number of jobs to use for computing the mutual information.
521
+ The parallelization is done on the columns of `X`.
522
+ ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
523
+ ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
524
+ for more details.
525
+
526
+ .. versionadded:: 1.5
527
+
471
528
Returns
472
529
-------
473
530
mi : ndarray, shape (n_features,)
@@ -511,4 +568,13 @@ def mutual_info_classif(
511
568
0. , 0. , 0. , 0. , 0. ])
512
569
"""
513
570
check_classification_targets (y )
514
- return _estimate_mi (X , y , discrete_features , True , n_neighbors , copy , random_state )
571
+ return _estimate_mi (
572
+ X ,
573
+ y ,
574
+ discrete_features = discrete_features ,
575
+ discrete_target = True ,
576
+ n_neighbors = n_neighbors ,
577
+ copy = copy ,
578
+ random_state = random_state ,
579
+ n_jobs = n_jobs ,
580
+ )
0 commit comments