31
31
MetadataRouter ,
32
32
MethodMapping ,
33
33
_raise_for_params ,
34
- _raise_for_unsupported_routing ,
35
34
_routing_enabled ,
36
- _RoutingNotSupportedMixin ,
37
35
process_routing ,
38
36
)
39
37
from .utils .metaestimators import _BaseComposition , available_if
@@ -1319,7 +1317,7 @@ def _fit_one(transformer, X, y, weight, message_clsname="", message=None, params
1319
1317
return transformer .fit (X , y , ** params ["fit" ])
1320
1318
1321
1319
1322
- class FeatureUnion (_RoutingNotSupportedMixin , TransformerMixin , _BaseComposition ):
1320
+ class FeatureUnion (TransformerMixin , _BaseComposition ):
1323
1321
"""Concatenates results of multiple transformer objects.
1324
1322
1325
1323
This estimator applies a list of transformer objects in parallel to the
@@ -1644,23 +1642,42 @@ def fit(self, X, y=None, **fit_params):
1644
1642
Targets for supervised learning.
1645
1643
1646
1644
**fit_params : dict, default=None
1647
- Parameters to pass to the fit method of the estimator.
1645
+ - If `enable_metadata_routing=False` (default):
1646
+ Parameters directly passed to the `fit` methods of the
1647
+ sub-transformers.
1648
+
1649
+ - If `enable_metadata_routing=True`:
1650
+ Parameters safely routed to the `fit` methods of the
1651
+ sub-transformers. See :ref:`Metadata Routing User Guide
1652
+ <metadata_routing>` for more details.
1653
+
1654
+ .. versionchanged:: 1.5
1655
+ `**fit_params` can be routed via metadata routing API.
1648
1656
1649
1657
Returns
1650
1658
-------
1651
1659
self : object
1652
1660
FeatureUnion class instance.
1653
1661
"""
1654
- _raise_for_unsupported_routing (self , "fit" , ** fit_params )
1655
- transformers = self ._parallel_func (X , y , fit_params , _fit_one )
1662
+ if _routing_enabled ():
1663
+ routed_params = process_routing (self , "fit" , ** fit_params )
1664
+ else :
1665
+ # TODO(SLEP6): remove when metadata routing cannot be disabled.
1666
+ routed_params = Bunch ()
1667
+ for name , _ in self .transformer_list :
1668
+ routed_params [name ] = Bunch (fit = {})
1669
+ routed_params [name ].fit = fit_params
1670
+
1671
+ transformers = self ._parallel_func (X , y , _fit_one , routed_params )
1672
+
1656
1673
if not transformers :
1657
1674
# All transformers are None
1658
1675
return self
1659
1676
1660
1677
self ._update_transformer_list (transformers )
1661
1678
return self
1662
1679
1663
- def fit_transform (self , X , y = None , ** fit_params ):
1680
+ def fit_transform (self , X , y = None , ** params ):
1664
1681
"""Fit all transformers, transform the data and concatenate results.
1665
1682
1666
1683
Parameters
@@ -1671,8 +1688,18 @@ def fit_transform(self, X, y=None, **fit_params):
1671
1688
y : array-like of shape (n_samples, n_outputs), default=None
1672
1689
Targets for supervised learning.
1673
1690
1674
- **fit_params : dict, default=None
1675
- Parameters to pass to the fit method of the estimator.
1691
+ **params : dict, default=None
1692
+ - If `enable_metadata_routing=False` (default):
1693
+ Parameters directly passed to the `fit` methods of the
1694
+ sub-transformers.
1695
+
1696
+ - If `enable_metadata_routing=True`:
1697
+ Parameters safely routed to the `fit` methods of the
1698
+ sub-transformers. See :ref:`Metadata Routing User Guide
1699
+ <metadata_routing>` for more details.
1700
+
1701
+ .. versionchanged:: 1.5
1702
+ `**params` can now be routed via metadata routing API.
1676
1703
1677
1704
Returns
1678
1705
-------
@@ -1681,7 +1708,21 @@ def fit_transform(self, X, y=None, **fit_params):
1681
1708
The `hstack` of results of transformers. `sum_n_components` is the
1682
1709
sum of `n_components` (output dimension) over transformers.
1683
1710
"""
1684
- results = self ._parallel_func (X , y , fit_params , _fit_transform_one )
1711
+ if _routing_enabled ():
1712
+ routed_params = process_routing (self , "fit_transform" , ** params )
1713
+ else :
1714
+ # TODO(SLEP6): remove when metadata routing cannot be disabled.
1715
+ routed_params = Bunch ()
1716
+ for name , obj in self .transformer_list :
1717
+ if hasattr (obj , "fit_transform" ):
1718
+ routed_params [name ] = Bunch (fit_transform = {})
1719
+ routed_params [name ].fit_transform = params
1720
+ else :
1721
+ routed_params [name ] = Bunch (fit = {})
1722
+ routed_params [name ] = Bunch (transform = {})
1723
+ routed_params [name ].fit = params
1724
+
1725
+ results = self ._parallel_func (X , y , _fit_transform_one , routed_params )
1685
1726
if not results :
1686
1727
# All transformers are None
1687
1728
return np .zeros ((X .shape [0 ], 0 ))
@@ -1696,15 +1737,13 @@ def _log_message(self, name, idx, total):
1696
1737
return None
1697
1738
return "(step %d of %d) Processing %s" % (idx , total , name )
1698
1739
1699
- def _parallel_func (self , X , y , fit_params , func ):
1740
+ def _parallel_func (self , X , y , func , routed_params ):
1700
1741
"""Runs func in parallel on X and y"""
1701
1742
self .transformer_list = list (self .transformer_list )
1702
1743
self ._validate_transformers ()
1703
1744
self ._validate_transformer_weights ()
1704
1745
transformers = list (self ._iter ())
1705
1746
1706
- params = Bunch (fit = fit_params , fit_transform = fit_params )
1707
-
1708
1747
return Parallel (n_jobs = self .n_jobs )(
1709
1748
delayed (func )(
1710
1749
transformer ,
@@ -1713,31 +1752,45 @@ def _parallel_func(self, X, y, fit_params, func):
1713
1752
weight ,
1714
1753
message_clsname = "FeatureUnion" ,
1715
1754
message = self ._log_message (name , idx , len (transformers )),
1716
- params = params ,
1755
+ params = routed_params [ name ] ,
1717
1756
)
1718
1757
for idx , (name , transformer , weight ) in enumerate (transformers , 1 )
1719
1758
)
1720
1759
1721
- def transform (self , X ):
1760
+ def transform (self , X , ** params ):
1722
1761
"""Transform X separately by each transformer, concatenate results.
1723
1762
1724
1763
Parameters
1725
1764
----------
1726
1765
X : iterable or array-like, depending on transformers
1727
1766
Input data to be transformed.
1728
1767
1768
+ **params : dict, default=None
1769
+
1770
+ Parameters routed to the `transform` method of the sub-transformers via the
1771
+ metadata routing API. See :ref:`Metadata Routing User Guide
1772
+ <metadata_routing>` for more details.
1773
+
1774
+ .. versionadded:: 1.5
1775
+
1729
1776
Returns
1730
1777
-------
1731
- X_t : array-like or sparse matrix of \
1732
- shape (n_samples, sum_n_components)
1778
+ X_t : array-like or sparse matrix of shape (n_samples, sum_n_components)
1733
1779
The `hstack` of results of transformers. `sum_n_components` is the
1734
1780
sum of `n_components` (output dimension) over transformers.
1735
1781
"""
1736
- # TODO(SLEP6): accept **params here in `transform` and route it to the
1737
- # underlying estimators.
1738
- params = Bunch (transform = {})
1782
+ _raise_for_params (params , self , "transform" )
1783
+
1784
+ if _routing_enabled ():
1785
+ routed_params = process_routing (self , "transform" , ** params )
1786
+ else :
1787
+ # TODO(SLEP6): remove when metadata routing cannot be disabled.
1788
+ routed_params = Bunch ()
1789
+ for name , _ in self .transformer_list :
1790
+ routed_params [name ] = Bunch (transform = {})
1791
+
1739
1792
Xs = Parallel (n_jobs = self .n_jobs )(
1740
- delayed (_transform_one )(trans , X , None , weight , params )
1793
+ delayed (_transform_one )(trans , X , None , weight , routed_params [ name ] )
1741
1794
for name , trans , weight in self ._iter ()
1742
1795
)
1743
1796
if not Xs :
@@ -1793,6 +1846,35 @@ def __getitem__(self, name):
1793
1846
raise KeyError ("Only string keys are supported" )
1794
1847
return self .named_transformers [name ]
1795
1848
1849
+ def get_metadata_routing (self ):
1850
+ """Get metadata routing of this object.
1851
+
1852
+ Please check :ref:`User Guide <metadata_routing>` on how the routing
1853
+ mechanism works.
1854
+
1855
+ .. versionadded:: 1.5
1856
+
1857
+ Returns
1858
+ -------
1859
+ routing : MetadataRouter
1860
+ A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
1861
+ routing information.
1862
+ """
1863
+ router = MetadataRouter (owner = self .__class__ .__name__ )
1864
+
1865
+ for name , transformer in self .transformer_list :
1866
+ router .add (
1867
+ ** {name : transformer },
1868
+ method_mapping = MethodMapping ()
1869
+ .add (caller = "fit" , callee = "fit" )
1870
+ .add (caller = "fit_transform" , callee = "fit_transform" )
1871
+ .add (caller = "fit_transform" , callee = "fit" )
1872
+ .add (caller = "fit_transform" , callee = "transform" )
1873
+ .add (caller = "transform" , callee = "transform" ),
1874
+ )
1875
+
1876
+ return router
1877
+
1796
1878
1797
1879
def make_union (* transformers , n_jobs = None , verbose = False ):
1798
1880
"""Construct a :class:`FeatureUnion` from the given transformers.
0 commit comments