diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 83eaded3b90ec..e2b18cd0149a2 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -521,6 +521,10 @@ Changelog be used when one of the transformers in the :class:`pipeline.FeatureUnion` is `"passthrough"`. :pr:`24058` by :user:`Diederik Perdok ` +- |Enhancement| The :class:`FeatureUnion` class now has a `named_transformers` + attribute for accessing transformers by name. + :pr:`20331` by :user:`Christopher Flynn `. + :mod:`sklearn.preprocessing` ............................ diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 7b9e5327de906..8da71119f381d 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -959,6 +959,14 @@ class FeatureUnion(TransformerMixin, _BaseComposition): Attributes ---------- + named_transformers : :class:`~sklearn.utils.Bunch` + Dictionary-like object, with the following attributes. + Read-only attribute to access any transformer parameter by user + given name. Keys are transformer names and values are + transformer parameters. + + .. versionadded:: 1.2 + n_features_in_ : int Number of features seen during :term:`fit`. Only defined if the underlying first transformer in `transformer_list` exposes such an @@ -1017,6 +1025,11 @@ def set_output(self, transform=None): _safe_set_output(step, transform=transform) return self + @property + def named_transformers(self): + # Use Bunch object to improve autocomplete + return Bunch(**dict(self.transformer_list)) + def get_params(self, deep=True): """Get parameters for this estimator. diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index 6bed7e520f438..07e3f7170efdf 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -524,6 +524,19 @@ def test_feature_union(): fs.fit(X, y) +def test_feature_union_named_transformers(): + """Check the behaviour of `named_transformers` attribute.""" + transf = Transf() + noinvtransf = NoInvTransf() + fs = FeatureUnion([("transf", transf), ("noinvtransf", noinvtransf)]) + assert fs.named_transformers["transf"] == transf + assert fs.named_transformers["noinvtransf"] == noinvtransf + + # test named attribute + assert fs.named_transformers.transf == transf + assert fs.named_transformers.noinvtransf == noinvtransf + + def test_make_union(): pca = PCA(svd_solver="full") mock = Transf()