18
18
)
19
19
from ..utils ._mask import _get_mask
20
20
from ..utils ._param_validation import HasMethods , Interval , StrOptions
21
- from ..utils .metadata_routing import _RoutingNotSupportedMixin
21
+ from ..utils .metadata_routing import (
22
+ MetadataRouter ,
23
+ MethodMapping ,
24
+ _raise_for_params ,
25
+ process_routing ,
26
+ )
22
27
from ..utils .validation import FLOAT_DTYPES , _check_feature_names_in , check_is_fitted
23
28
from ._base import SimpleImputer , _BaseImputer , _check_inputs_dtype
24
29
@@ -47,7 +52,7 @@ def _assign_where(X1, X2, cond):
47
52
X1 [cond ] = X2 [cond ]
48
53
49
54
50
- class IterativeImputer (_RoutingNotSupportedMixin , _BaseImputer ):
55
+ class IterativeImputer (_BaseImputer ):
51
56
"""Multivariate imputer that estimates each feature from all the others.
52
57
53
58
A strategy for imputing missing values by modeling each feature with
@@ -349,6 +354,7 @@ def _impute_one_feature(
349
354
neighbor_feat_idx ,
350
355
estimator = None ,
351
356
fit_mode = True ,
357
+ params = None ,
352
358
):
353
359
"""Impute a single feature from the others provided.
354
360
@@ -380,6 +386,9 @@ def _impute_one_feature(
380
386
fit_mode : boolean, default=True
381
387
Whether to fit and predict with the estimator or just predict.
382
388
389
+ params : dict
390
+ Additional params routed to the individual estimator.
391
+
383
392
Returns
384
393
-------
385
394
X_filled : ndarray
@@ -410,7 +419,7 @@ def _impute_one_feature(
410
419
~ missing_row_mask ,
411
420
axis = 0 ,
412
421
)
413
- estimator .fit (X_train , y_train )
422
+ estimator .fit (X_train , y_train , ** params )
414
423
415
424
# if no missing values, don't predict
416
425
if np .sum (missing_row_mask ) == 0 :
@@ -685,7 +694,7 @@ def _validate_limit(limit, limit_type, n_features):
685
694
# IterativeImputer.estimator is not validated yet
686
695
prefer_skip_nested_validation = False
687
696
)
688
- def fit_transform (self , X , y = None ):
697
+ def fit_transform (self , X , y = None , ** params ):
689
698
"""Fit the imputer on `X` and return the transformed `X`.
690
699
691
700
Parameters
@@ -697,11 +706,29 @@ def fit_transform(self, X, y=None):
697
706
y : Ignored
698
707
Not used, present for API consistency by convention.
699
708
709
+ **params : dict
710
+ Parameters routed to the `fit` method of the sub-estimator via the
711
+ metadata routing API.
712
+
713
+ .. versionadded:: 1.5
714
+ Only available if
715
+ `sklearn.set_config(enable_metadata_routing=True)` is set. See
716
+ :ref:`Metadata Routing User Guide <metadata_routing>` for more
717
+ details.
718
+
700
719
Returns
701
720
-------
702
721
Xt : array-like, shape (n_samples, n_features)
703
722
The imputed input data.
704
723
"""
724
+ _raise_for_params (params , self , "fit" )
725
+
726
+ routed_params = process_routing (
727
+ self ,
728
+ "fit" ,
729
+ ** params ,
730
+ )
731
+
705
732
self .random_state_ = getattr (
706
733
self , "random_state_" , check_random_state (self .random_state )
707
734
)
@@ -728,7 +755,7 @@ def fit_transform(self, X, y=None):
728
755
self .n_iter_ = 0
729
756
return super ()._concatenate_indicator (Xt , X_indicator )
730
757
731
- # Edge case: a single feature. We return the initial .. .
758
+ # Edge case: a single feature, we return the initial imputation .
732
759
if Xt .shape [1 ] == 1 :
733
760
self .n_iter_ = 0
734
761
return super ()._concatenate_indicator (Xt , X_indicator )
@@ -770,6 +797,7 @@ def fit_transform(self, X, y=None):
770
797
neighbor_feat_idx ,
771
798
estimator = None ,
772
799
fit_mode = True ,
800
+ params = routed_params .estimator .fit ,
773
801
)
774
802
estimator_triplet = _ImputerTriplet (
775
803
feat_idx , neighbor_feat_idx , estimator
@@ -860,7 +888,7 @@ def transform(self, X):
860
888
861
889
return super ()._concatenate_indicator (Xt , X_indicator )
862
890
863
- def fit (self , X , y = None ):
891
+ def fit (self , X , y = None , ** fit_params ):
864
892
"""Fit the imputer on `X` and return self.
865
893
866
894
Parameters
@@ -872,12 +900,22 @@ def fit(self, X, y=None):
872
900
y : Ignored
873
901
Not used, present for API consistency by convention.
874
902
903
+ **fit_params : dict
904
+ Parameters routed to the `fit` method of the sub-estimator via the
905
+ metadata routing API.
906
+
907
+ .. versionadded:: 1.5
908
+ Only available if
909
+ `sklearn.set_config(enable_metadata_routing=True)` is set. See
910
+ :ref:`Metadata Routing User Guide <metadata_routing>` for more
911
+ details.
912
+
875
913
Returns
876
914
-------
877
915
self : object
878
916
Fitted estimator.
879
917
"""
880
- self .fit_transform (X )
918
+ self .fit_transform (X , ** fit_params )
881
919
return self
882
920
883
921
def get_feature_names_out (self , input_features = None ):
@@ -904,3 +942,23 @@ def get_feature_names_out(self, input_features=None):
904
942
input_features = _check_feature_names_in (self , input_features )
905
943
names = self .initial_imputer_ .get_feature_names_out (input_features )
906
944
return self ._concatenate_indicator_feature_names_out (names , input_features )
945
+
946
+ def get_metadata_routing (self ):
947
+ """Get metadata routing of this object.
948
+
949
+ Please check :ref:`User Guide <metadata_routing>` on how the routing
950
+ mechanism works.
951
+
952
+ .. versionadded:: 1.5
953
+
954
+ Returns
955
+ -------
956
+ routing : MetadataRouter
957
+ A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
958
+ routing information.
959
+ """
960
+ router = MetadataRouter (owner = self .__class__ .__name__ ).add (
961
+ estimator = self .estimator ,
962
+ method_mapping = MethodMapping ().add (callee = "fit" , caller = "fit" ),
963
+ )
964
+ return router
0 commit comments