Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit be35d8c

Browse files
ENH Add metadata routing for IterativeImputer (scikit-learn#28187)
1 parent 9a6e6dd commit be35d8c

File tree

5 files changed

+88
-11
lines changed

5 files changed

+88
-11
lines changed

doc/metadata_routing.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ Meta-estimators and functions supporting metadata routing:
252252
- :class:`sklearn.calibration.CalibratedClassifierCV`
253253
- :class:`sklearn.compose.ColumnTransformer`
254254
- :class:`sklearn.feature_selection.SelectFromModel`
255+
- :class:`sklearn.impute.IterativeImputer`
255256
- :class:`sklearn.linear_model.ElasticNetCV`
256257
- :class:`sklearn.linear_model.LarsCV`
257258
- :class:`sklearn.linear_model.LassoCV`
@@ -291,7 +292,6 @@ Meta-estimators and tools not supporting metadata routing yet:
291292
- :class:`sklearn.feature_selection.RFE`
292293
- :class:`sklearn.feature_selection.RFECV`
293294
- :class:`sklearn.feature_selection.SequentialFeatureSelector`
294-
- :class:`sklearn.impute.IterativeImputer`
295295
- :class:`sklearn.linear_model.RANSACRegressor`
296296
- :class:`sklearn.linear_model.RidgeClassifierCV`
297297
- :class:`sklearn.linear_model.RidgeCV`

doc/whats_new/v1.5.rst

+11
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,17 @@ Meson is now supported as a build backend, see :ref:`Building with Meson
3232

3333
TODO Fill more details before the 1.5 release, when the Meson story has settled down.
3434

35+
Metadata Routing
36+
----------------
37+
38+
The following models now support metadata routing in one or more of their
39+
methods. Refer to the :ref:`Metadata Routing User Guide <metadata_routing>` for
40+
more details.
41+
42+
- |Feature| :class:`impute.IterativeImputer` now supports metadata routing in
43+
its `fit` method. :pr:`28187` by :user:`Stefanie Senger <StefanieSenger>`.
44+
45+
3546
Changelog
3647
---------
3748

sklearn/feature_selection/_from_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -349,12 +349,12 @@ def fit(self, X, y=None, **fit_params):
349349
**fit_params : dict
350350
- If `enable_metadata_routing=False` (default):
351351
352-
Parameters directly passed to the `partial_fit` method of the
352+
Parameters directly passed to the `fit` method of the
353353
sub-estimator. They are ignored if `prefit=True`.
354354
355355
- If `enable_metadata_routing=True`:
356356
357-
Parameters safely routed to the `partial_fit` method of the
357+
Parameters safely routed to the `fit` method of the
358358
sub-estimator. They are ignored if `prefit=True`.
359359
360360
.. versionchanged:: 1.4

sklearn/impute/_iterative.py

+65-7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
)
1919
from ..utils._mask import _get_mask
2020
from ..utils._param_validation import HasMethods, Interval, StrOptions
21-
from ..utils.metadata_routing import _RoutingNotSupportedMixin
21+
from ..utils.metadata_routing import (
22+
MetadataRouter,
23+
MethodMapping,
24+
_raise_for_params,
25+
process_routing,
26+
)
2227
from ..utils.validation import FLOAT_DTYPES, _check_feature_names_in, check_is_fitted
2328
from ._base import SimpleImputer, _BaseImputer, _check_inputs_dtype
2429

@@ -47,7 +52,7 @@ def _assign_where(X1, X2, cond):
4752
X1[cond] = X2[cond]
4853

4954

50-
class IterativeImputer(_RoutingNotSupportedMixin, _BaseImputer):
55+
class IterativeImputer(_BaseImputer):
5156
"""Multivariate imputer that estimates each feature from all the others.
5257
5358
A strategy for imputing missing values by modeling each feature with
@@ -349,6 +354,7 @@ def _impute_one_feature(
349354
neighbor_feat_idx,
350355
estimator=None,
351356
fit_mode=True,
357+
params=None,
352358
):
353359
"""Impute a single feature from the others provided.
354360
@@ -380,6 +386,9 @@ def _impute_one_feature(
380386
fit_mode : boolean, default=True
381387
Whether to fit and predict with the estimator or just predict.
382388
389+
params : dict, default=None
390+
Additional params routed to the individual estimator.
391+
383392
Returns
384393
-------
385394
X_filled : ndarray
@@ -410,7 +419,7 @@ def _impute_one_feature(
410419
~missing_row_mask,
411420
axis=0,
412421
)
413-
estimator.fit(X_train, y_train)
422+
estimator.fit(X_train, y_train, **params)
414423

415424
# if no missing values, don't predict
416425
if np.sum(missing_row_mask) == 0:
@@ -685,7 +694,7 @@ def _validate_limit(limit, limit_type, n_features):
685694
# IterativeImputer.estimator is not validated yet
686695
prefer_skip_nested_validation=False
687696
)
688-
def fit_transform(self, X, y=None):
697+
def fit_transform(self, X, y=None, **params):
689698
"""Fit the imputer on `X` and return the transformed `X`.
690699
691700
Parameters
@@ -697,11 +706,29 @@ def fit_transform(self, X, y=None):
697706
y : Ignored
698707
Not used, present for API consistency by convention.
699708
709+
**params : dict
710+
Parameters routed to the `fit` method of the sub-estimator via the
711+
metadata routing API.
712+
713+
.. versionadded:: 1.5
714+
Only available if
715+
`sklearn.set_config(enable_metadata_routing=True)` is set. See
716+
:ref:`Metadata Routing User Guide <metadata_routing>` for more
717+
details.
718+
700719
Returns
701720
-------
702721
Xt : array-like, shape (n_samples, n_features)
703722
The imputed input data.
704723
"""
724+
_raise_for_params(params, self, "fit")
725+
726+
routed_params = process_routing(
727+
self,
728+
"fit",
729+
**params,
730+
)
731+
705732
self.random_state_ = getattr(
706733
self, "random_state_", check_random_state(self.random_state)
707734
)
@@ -728,7 +755,7 @@ def fit_transform(self, X, y=None):
728755
self.n_iter_ = 0
729756
return super()._concatenate_indicator(Xt, X_indicator)
730757

731-
# Edge case: a single feature. We return the initial ...
758+
# Edge case: a single feature, we return the initial imputation.
732759
if Xt.shape[1] == 1:
733760
self.n_iter_ = 0
734761
return super()._concatenate_indicator(Xt, X_indicator)
@@ -770,6 +797,7 @@ def fit_transform(self, X, y=None):
770797
neighbor_feat_idx,
771798
estimator=None,
772799
fit_mode=True,
800+
params=routed_params.estimator.fit,
773801
)
774802
estimator_triplet = _ImputerTriplet(
775803
feat_idx, neighbor_feat_idx, estimator
@@ -860,7 +888,7 @@ def transform(self, X):
860888

861889
return super()._concatenate_indicator(Xt, X_indicator)
862890

863-
def fit(self, X, y=None):
891+
def fit(self, X, y=None, **fit_params):
864892
"""Fit the imputer on `X` and return self.
865893
866894
Parameters
@@ -872,12 +900,22 @@ def fit(self, X, y=None):
872900
y : Ignored
873901
Not used, present for API consistency by convention.
874902
903+
**fit_params : dict
904+
Parameters routed to the `fit` method of the sub-estimator via the
905+
metadata routing API.
906+
907+
.. versionadded:: 1.5
908+
Only available if
909+
`sklearn.set_config(enable_metadata_routing=True)` is set. See
910+
:ref:`Metadata Routing User Guide <metadata_routing>` for more
911+
details.
912+
875913
Returns
876914
-------
877915
self : object
878916
Fitted estimator.
879917
"""
880-
self.fit_transform(X)
918+
self.fit_transform(X, **fit_params)
881919
return self
882920

883921
def get_feature_names_out(self, input_features=None):
@@ -904,3 +942,23 @@ def get_feature_names_out(self, input_features=None):
904942
input_features = _check_feature_names_in(self, input_features)
905943
names = self.initial_imputer_.get_feature_names_out(input_features)
906944
return self._concatenate_indicator_feature_names_out(names, input_features)
945+
946+
def get_metadata_routing(self):
947+
"""Get metadata routing of this object.
948+
949+
Please check :ref:`User Guide <metadata_routing>` on how the routing
950+
mechanism works.
951+
952+
.. versionadded:: 1.5
953+
954+
Returns
955+
-------
956+
routing : MetadataRouter
957+
A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
958+
routing information.
959+
"""
960+
router = MetadataRouter(owner=self.__class__.__name__).add(
961+
estimator=self.estimator,
962+
method_mapping=MethodMapping().add(callee="fit", caller="fit"),
963+
)
964+
return router

sklearn/tests/test_metaestimators_metadata_routing.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,15 @@ def enable_slep006():
289289
"cv_name": "cv",
290290
"cv_routing_methods": ["fit"],
291291
},
292+
{
293+
"metaestimator": IterativeImputer,
294+
"estimator_name": "estimator",
295+
"estimator": ConsumingRegressor,
296+
"init_args": {"skip_complete": False},
297+
"X": X,
298+
"y": y,
299+
"estimator_routing_methods": ["fit"],
300+
},
292301
]
293302
"""List containing all metaestimators to be tested and their settings
294303
@@ -331,7 +340,6 @@ def enable_slep006():
331340
BaggingRegressor(),
332341
FeatureUnion([]),
333342
GraphicalLassoCV(),
334-
IterativeImputer(),
335343
RANSACRegressor(),
336344
RFE(ConsumingClassifier()),
337345
RFECV(ConsumingClassifier()),

0 commit comments

Comments
 (0)