From 4471af00445a9605078059df465837d1c032d8b5 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 1 Dec 2022 16:06:26 -0500 Subject: [PATCH 1/4] ENH Adds FeatureUnion.__getitem__ to access transformers --- doc/whats_new/v1.3.rst | 6 ++++++ sklearn/pipeline.py | 4 ++++ sklearn/tests/test_pipeline.py | 18 ++++++++++++++++++ 3 files changed, 28 insertions(+) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 6012ac7a336a4..81eeee0c7a5ee 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -36,6 +36,12 @@ Changelog :pr:`123456` by :user:`Joe Bloggs <joeongithub>`. where 123456 is the *pull request* number, not the issue number. +:mod:`sklearn.pipeline` +....................... +- |Feature| :class:`pipeline.FeatureUnion` can now use indexing notation (e.g. + `feature_union["scalar"]`) to access transformers by name. :pr:`xxxxx` by + `Thomas Fan`_. + Code and Documentation Contributors ----------------------------------- diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index b3a4d180a4c68..5f13f1316ff64 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -1286,6 +1286,10 @@ def _sk_visual_block_(self): names, transformers = zip(*self.transformer_list) return _VisualBlock("parallel", transformers, names=names) + def __getitem__(self, name): + """Return transformer with name.""" + return self.named_transformers[name] + def make_union(*transformers, n_jobs=None, verbose=False): """Construct a FeatureUnion from the given transformers. 
diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index 07e3f7170efdf..423c3260a4fd8 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -1658,3 +1658,21 @@ def test_feature_union_set_output(): assert isinstance(X_trans, pd.DataFrame) assert_array_equal(X_trans.columns, union.get_feature_names_out()) assert_array_equal(X_trans.index, X_test.index) + + +def test_feature_union_getitem(): + """Check FeatureUnion.__getitem__ returns expected results.""" + scalar = StandardScaler() + pca = PCA() + union = FeatureUnion( + [ + ("scalar", scalar), + ("pca", pca), + ("pass", "passthrough"), + ("drop_me", "drop"), + ] + ) + assert union["scalar"] is scalar + assert union["pca"] is pca + assert union["scalar"] is scalar + assert union["pass"] == "passthrough" From 11ccc95169a8e290695e711395021377765abd58 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 1 Dec 2022 17:38:20 -0500 Subject: [PATCH 2/4] DOC Adds whats new number --- doc/whats_new/v1.3.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 81eeee0c7a5ee..a7838f7c2ec8f 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -39,7 +39,7 @@ Changelog :mod:`sklearn.pipeline` ....................... - |Feature| :class:`pipeline.FeatureUnion` can now use indexing notation (e.g. - `feature_union["scalar"]`) to access transformers by name. :pr:`xxxxx` by + `feature_union["scalar"]`) to access transformers by name. :pr:`25093` by `Thomas Fan`_. Code and Documentation Contributors From f5a59304c876a164c73d647a3d8d6e7dc628bd6c Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Fri, 2 Dec 2022 15:38:48 -0500 Subject: [PATCH 3/4] Update sklearn/tests/test_pipeline.py Co-authored-by: Meekail Zain <34613774+Micky774@users.noreply.github.com> --- sklearn/tests/test_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index 423c3260a4fd8..71ff0cbd8990e 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -1674,5 +1674,5 @@ def test_feature_union_getitem(): ) assert union["scalar"] is scalar assert union["pca"] is pca - assert union["scalar"] is scalar assert union["pass"] == "passthrough" + assert union["drop_me"] == "drop" From 07c55cc44a9fc849b10a05477d4aeb57936ac870 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 5 Dec 2022 11:45:54 -0500 Subject: [PATCH 4/4] ENH Validate string keys in __getitem__ --- sklearn/pipeline.py | 2 ++ sklearn/tests/test_pipeline.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 5f13f1316ff64..cef3288b85439 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -1288,6 +1288,8 @@ def _sk_visual_block_(self): def __getitem__(self, name): """Return transformer with name.""" + if not isinstance(name, str): + raise KeyError("Only string keys are supported") return self.named_transformers[name] diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index 71ff0cbd8990e..eab7d8027b3cd 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -1676,3 +1676,14 @@ def test_feature_union_getitem(): assert union["pca"] is pca assert union["pass"] == "passthrough" assert union["drop_me"] == "drop" + + +@pytest.mark.parametrize("key", [0, slice(0, 2)]) +def test_feature_union_getitem_error(key): + """Raise error when __getitem__ gets a non-string input.""" + + union = FeatureUnion([("scalar", StandardScaler()), ("pca", PCA())]) + + msg = "Only string keys are supported" + 
with pytest.raises(KeyError, match=msg): + union[key]