Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit b1ce4c1

Browse files
StefanieSenger, adrinjalali, glemaitre
authored
ENH Add metadata routing for FeatureUnion (#28205)
Co-authored-by: Adrin Jalali <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent e5b0488 commit b1ce4c1

File tree

6 files changed

+261
-32
lines changed

6 files changed

+261
-32
lines changed

doc/metadata_routing.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ Meta-estimators and functions supporting metadata routing:
305305
- :class:`sklearn.multioutput.MultiOutputRegressor`
306306
- :class:`sklearn.linear_model.OrthogonalMatchingPursuitCV`
307307
- :class:`sklearn.multioutput.RegressorChain`
308+
- :class:`sklearn.pipeline.FeatureUnion`
308309
- :class:`sklearn.pipeline.Pipeline`
309310

310311
Meta-estimators and tools not supporting metadata routing yet:
@@ -323,5 +324,4 @@ Meta-estimators and tools not supporting metadata routing yet:
323324
- :class:`sklearn.model_selection.learning_curve`
324325
- :class:`sklearn.model_selection.permutation_test_score`
325326
- :class:`sklearn.model_selection.validation_curve`
326-
- :class:`sklearn.pipeline.FeatureUnion`
327327
- :class:`sklearn.semi_supervised.SelfTrainingClassifier`

doc/whats_new/v1.5.rst

+6
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ more details.
7171
``**fit_params`` to the underlying estimators via their `fit` methods.
7272
:pr:`27584` by :user:`Stefanie Senger <StefanieSenger>`.
7373

74+
- |Feature| :class:`pipeline.FeatureUnion` now supports metadata routing in its
75+
``fit`` and ``fit_transform`` methods and routes metadata to the underlying
76+
transformers' ``fit`` and ``fit_transform``. :pr:`28205` by :user:`Stefanie
77+
Senger <StefanieSenger>`.
78+
79+
7480
Changelog
7581
---------
7682

sklearn/pipeline.py

+103-21
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@
3131
MetadataRouter,
3232
MethodMapping,
3333
_raise_for_params,
34-
_raise_for_unsupported_routing,
3534
_routing_enabled,
36-
_RoutingNotSupportedMixin,
3735
process_routing,
3836
)
3937
from .utils.metaestimators import _BaseComposition, available_if
@@ -1319,7 +1317,7 @@ def _fit_one(transformer, X, y, weight, message_clsname="", message=None, params
13191317
return transformer.fit(X, y, **params["fit"])
13201318

13211319

1322-
class FeatureUnion(_RoutingNotSupportedMixin, TransformerMixin, _BaseComposition):
1320+
class FeatureUnion(TransformerMixin, _BaseComposition):
13231321
"""Concatenates results of multiple transformer objects.
13241322
13251323
This estimator applies a list of transformer objects in parallel to the
@@ -1644,23 +1642,42 @@ def fit(self, X, y=None, **fit_params):
16441642
Targets for supervised learning.
16451643
16461644
**fit_params : dict, default=None
1647-
Parameters to pass to the fit method of the estimator.
1645+
- If `enable_metadata_routing=False` (default):
1646+
Parameters directly passed to the `fit` methods of the
1647+
sub-transformers.
1648+
1649+
- If `enable_metadata_routing=True`:
1650+
Parameters safely routed to the `fit` methods of the
1651+
sub-transformers. See :ref:`Metadata Routing User Guide
1652+
<metadata_routing>` for more details.
1653+
1654+
.. versionchanged:: 1.5
1655+
`**fit_params` can be routed via metadata routing API.
16481656
16491657
Returns
16501658
-------
16511659
self : object
16521660
FeatureUnion class instance.
16531661
"""
1654-
_raise_for_unsupported_routing(self, "fit", **fit_params)
1655-
transformers = self._parallel_func(X, y, fit_params, _fit_one)
1662+
if _routing_enabled():
1663+
routed_params = process_routing(self, "fit", **fit_params)
1664+
else:
1665+
# TODO(SLEP6): remove when metadata routing cannot be disabled.
1666+
routed_params = Bunch()
1667+
for name, _ in self.transformer_list:
1668+
routed_params[name] = Bunch(fit={})
1669+
routed_params[name].fit = fit_params
1670+
1671+
transformers = self._parallel_func(X, y, _fit_one, routed_params)
1672+
16561673
if not transformers:
16571674
# All transformers are None
16581675
return self
16591676

16601677
self._update_transformer_list(transformers)
16611678
return self
16621679

1663-
def fit_transform(self, X, y=None, **fit_params):
1680+
def fit_transform(self, X, y=None, **params):
16641681
"""Fit all transformers, transform the data and concatenate results.
16651682
16661683
Parameters
@@ -1671,8 +1688,18 @@ def fit_transform(self, X, y=None, **fit_params):
16711688
y : array-like of shape (n_samples, n_outputs), default=None
16721689
Targets for supervised learning.
16731690
1674-
**fit_params : dict, default=None
1675-
Parameters to pass to the fit method of the estimator.
1691+
**params : dict, default=None
1692+
- If `enable_metadata_routing=False` (default):
1693+
Parameters directly passed to the `fit` methods of the
1694+
sub-transformers.
1695+
1696+
- If `enable_metadata_routing=True`:
1697+
Parameters safely routed to the `fit` methods of the
1698+
sub-transformers. See :ref:`Metadata Routing User Guide
1699+
<metadata_routing>` for more details.
1700+
1701+
.. versionchanged:: 1.5
1702+
`**params` can now be routed via metadata routing API.
16761703
16771704
Returns
16781705
-------
@@ -1681,7 +1708,21 @@ def fit_transform(self, X, y=None, **fit_params):
16811708
The `hstack` of results of transformers. `sum_n_components` is the
16821709
sum of `n_components` (output dimension) over transformers.
16831710
"""
1684-
results = self._parallel_func(X, y, fit_params, _fit_transform_one)
1711+
if _routing_enabled():
1712+
routed_params = process_routing(self, "fit_transform", **params)
1713+
else:
1714+
# TODO(SLEP6): remove when metadata routing cannot be disabled.
1715+
routed_params = Bunch()
1716+
for name, obj in self.transformer_list:
1717+
if hasattr(obj, "fit_transform"):
1718+
routed_params[name] = Bunch(fit_transform={})
1719+
routed_params[name].fit_transform = params
1720+
else:
1721+
routed_params[name] = Bunch(fit={}, transform={})
1722+
1723+
routed_params[name].fit = params
1724+
1725+
results = self._parallel_func(X, y, _fit_transform_one, routed_params)
16851726
if not results:
16861727
# All transformers are None
16871728
return np.zeros((X.shape[0], 0))
@@ -1696,15 +1737,13 @@ def _log_message(self, name, idx, total):
16961737
return None
16971738
return "(step %d of %d) Processing %s" % (idx, total, name)
16981739

1699-
def _parallel_func(self, X, y, fit_params, func):
1740+
def _parallel_func(self, X, y, func, routed_params):
17001741
"""Runs func in parallel on X and y"""
17011742
self.transformer_list = list(self.transformer_list)
17021743
self._validate_transformers()
17031744
self._validate_transformer_weights()
17041745
transformers = list(self._iter())
17051746

1706-
params = Bunch(fit=fit_params, fit_transform=fit_params)
1707-
17081747
return Parallel(n_jobs=self.n_jobs)(
17091748
delayed(func)(
17101749
transformer,
@@ -1713,31 +1752,45 @@ def _parallel_func(self, X, y, fit_params, func):
17131752
weight,
17141753
message_clsname="FeatureUnion",
17151754
message=self._log_message(name, idx, len(transformers)),
1716-
params=params,
1755+
params=routed_params[name],
17171756
)
17181757
for idx, (name, transformer, weight) in enumerate(transformers, 1)
17191758
)
17201759

1721-
def transform(self, X):
1760+
def transform(self, X, **params):
17221761
"""Transform X separately by each transformer, concatenate results.
17231762
17241763
Parameters
17251764
----------
17261765
X : iterable or array-like, depending on transformers
17271766
Input data to be transformed.
17281767
1768+
**params : dict, default=None
1769+
1770+
Parameters routed to the `transform` method of the sub-transformers via the
1771+
metadata routing API. See :ref:`Metadata Routing User Guide
1772+
<metadata_routing>` for more details.
1773+
1774+
.. versionadded:: 1.5
1775+
17291776
Returns
17301777
-------
1731-
X_t : array-like or sparse matrix of \
1732-
shape (n_samples, sum_n_components)
1778+
X_t : array-like or sparse matrix of shape (n_samples, sum_n_components)
17331779
The `hstack` of results of transformers. `sum_n_components` is the
17341780
sum of `n_components` (output dimension) over transformers.
17351781
"""
1736-
# TODO(SLEP6): accept **params here in `transform` and route it to the
1737-
# underlying estimators.
1738-
params = Bunch(transform={})
1782+
_raise_for_params(params, self, "transform")
1783+
1784+
if _routing_enabled():
1785+
routed_params = process_routing(self, "transform", **params)
1786+
else:
1787+
# TODO(SLEP6): remove when metadata routing cannot be disabled.
1788+
routed_params = Bunch()
1789+
for name, _ in self.transformer_list:
1790+
routed_params[name] = Bunch(transform={})
1791+
17391792
Xs = Parallel(n_jobs=self.n_jobs)(
1740-
delayed(_transform_one)(trans, X, None, weight, params)
1793+
delayed(_transform_one)(trans, X, None, weight, routed_params[name])
17411794
for name, trans, weight in self._iter()
17421795
)
17431796
if not Xs:
@@ -1793,6 +1846,35 @@ def __getitem__(self, name):
17931846
raise KeyError("Only string keys are supported")
17941847
return self.named_transformers[name]
17951848

1849+
def get_metadata_routing(self):
1850+
"""Get metadata routing of this object.
1851+
1852+
Please check :ref:`User Guide <metadata_routing>` on how the routing
1853+
mechanism works.
1854+
1855+
.. versionadded:: 1.5
1856+
1857+
Returns
1858+
-------
1859+
routing : MetadataRouter
1860+
A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
1861+
routing information.
1862+
"""
1863+
router = MetadataRouter(owner=self.__class__.__name__)
1864+
1865+
for name, transformer in self.transformer_list:
1866+
router.add(
1867+
**{name: transformer},
1868+
method_mapping=MethodMapping()
1869+
.add(caller="fit", callee="fit")
1870+
.add(caller="fit_transform", callee="fit_transform")
1871+
.add(caller="fit_transform", callee="fit")
1872+
.add(caller="fit_transform", callee="transform")
1873+
.add(caller="transform", callee="transform"),
1874+
)
1875+
1876+
return router
1877+
17961878

17971879
def make_union(*transformers, n_jobs=None, verbose=False):
17981880
"""Construct a :class:`FeatureUnion` from the given transformers.

sklearn/tests/metadata_routing_common.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,9 @@ def check_recorded_metadata(obj, method, split_params=tuple(), **kwargs):
5454
sub-estimator's method where metadata is routed to
5555
split_params : tuple, default=empty
5656
specifies any parameters which are to be checked as being a subset
57-
of the original values.
58-
**kwargs : metadata to check
57+
of the original values
58+
**kwargs : dict
59+
passed metadata
5960
"""
6061
records = getattr(obj, "_records", dict()).get(method, dict())
6162
assert set(kwargs.keys()) == set(
@@ -338,6 +339,29 @@ def inverse_transform(self, X, sample_weight=None, metadata=None):
338339
return X
339340

340341

342+
class ConsumingNoFitTransformTransformer(BaseEstimator):
343+
"""A metadata consuming transformer that doesn't inherit from
344+
TransformerMixin, and thus doesn't implement `fit_transform`. Note that
345+
TransformerMixin's `fit_transform` doesn't route metadata to `transform`."""
346+
347+
def __init__(self, registry=None):
348+
self.registry = registry
349+
350+
def fit(self, X, y=None, sample_weight=None, metadata=None):
351+
if self.registry is not None:
352+
self.registry.append(self)
353+
354+
record_metadata(self, "fit", sample_weight=sample_weight, metadata=metadata)
355+
356+
return self
357+
358+
def transform(self, X, sample_weight=None, metadata=None):
359+
record_metadata(
360+
self, "transform", sample_weight=sample_weight, metadata=metadata
361+
)
362+
return X
363+
364+
341365
class ConsumingScorer(_Scorer):
342366
def __init__(self, registry=None):
343367
super().__init__(

sklearn/tests/test_metaestimators_metadata_routing.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
MultiOutputRegressor,
6060
RegressorChain,
6161
)
62-
from sklearn.pipeline import FeatureUnion
6362
from sklearn.semi_supervised import SelfTrainingClassifier
6463
from sklearn.tests.metadata_routing_common import (
6564
ConsumingClassifier,
@@ -362,7 +361,7 @@ def enable_slep006():
362361
363362
The keys are as follows:
364363
365-
- metaestimator: The metaestmator to be tested
364+
- metaestimator: The metaestimator to be tested
366365
- estimator_name: The name of the argument for the sub-estimator
367366
- estimator: The sub-estimator type, either "regressor" or "classifier"
368367
- init_args: The arguments to be passed to the metaestimator's constructor
@@ -398,7 +397,6 @@ def enable_slep006():
398397
UNSUPPORTED_ESTIMATORS = [
399398
AdaBoostClassifier(),
400399
AdaBoostRegressor(),
401-
FeatureUnion([]),
402400
GraphicalLassoCV(),
403401
RFE(ConsumingClassifier()),
404402
RFECV(ConsumingClassifier()),

0 commit comments

Comments
 (0)