From ff29e5e70dfd02fe5e21fc966e45ef0766b47497 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 6 Sep 2021 21:50:05 -0400 Subject: [PATCH 1/3] BUG Fixes FunctionTransformer validation in inverse_transform --- .../preprocessing/_function_transformer.py | 13 +++++----- .../tests/test_function_transformer.py | 24 +++++++++++++++++++ 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/sklearn/preprocessing/_function_transformer.py b/sklearn/preprocessing/_function_transformer.py index 20ee90f5f253f..9ec781196e27b 100644 --- a/sklearn/preprocessing/_function_transformer.py +++ b/sklearn/preprocessing/_function_transformer.py @@ -1,7 +1,7 @@ import warnings from ..base import BaseEstimator, TransformerMixin -from ..utils.validation import _allclose_dense_sparse +from ..utils.validation import _allclose_dense_sparse, check_array def _identity(X): @@ -110,9 +110,9 @@ def __init__( self.kw_args = kw_args self.inv_kw_args = inv_kw_args - def _check_input(self, X): + def _check_input(self, X, *, reset): if self.validate: - return self._validate_data(X, accept_sparse=self.accept_sparse) + return self._validate_data(X, accept_sparse=self.accept_sparse, reset=reset) return X def _check_inverse_transform(self, X): @@ -146,7 +146,7 @@ def fit(self, X, y=None): self : object FunctionTransformer class instance. """ - X = self._check_input(X) + X = self._check_input(X, reset=True) if self.check_inverse and not (self.func is None or self.inverse_func is None): self._check_inverse_transform(X) return self @@ -164,6 +164,7 @@ def transform(self, X): X_out : array-like, shape (n_samples, n_features) Transformed input. """ + X = self._check_input(X, reset=False) return self._transform(X, func=self.func, kw_args=self.kw_args) def inverse_transform(self, X): @@ -179,11 +180,11 @@ def inverse_transform(self, X): X_out : array-like, shape (n_samples, n_features) Transformed input. """ + if self.validate: + X = check_array(X, accept_sparse=self.accept_sparse) return self._transform(X, func=self.inverse_func, kw_args=self.inv_kw_args) def _transform(self, X, func=None, kw_args=None): - X = self._check_input(X) - if func is None: func = _identity diff --git a/sklearn/preprocessing/tests/test_function_transformer.py b/sklearn/preprocessing/tests/test_function_transformer.py index b3e517ac0c36c..b1ba9ebe6b762 100644 --- a/sklearn/preprocessing/tests/test_function_transformer.py +++ b/sklearn/preprocessing/tests/test_function_transformer.py @@ -174,3 +174,27 @@ def test_function_transformer_frame(): transformer = FunctionTransformer() X_df_trans = transformer.fit_transform(X_df) assert hasattr(X_df_trans, "loc") + + +def test_function_transformer_validate_inverse(): + """Test that function transformer does not reset estimator in + `inverse_transform`.""" + + def add_constant_feature(X): + X_one = np.ones((X.shape[0], 1)) + return np.concatenate((X, X_one), axis=1) + + def inverse_add_constant(X): + return X[:, :-1] + + X = np.array([[1, 2], [3, 4], [3, 4]]) + trans = FunctionTransformer( + func=add_constant_feature, + inverse_func=inverse_add_constant, + validate=True, + ) + X_trans = trans.fit_transform(X) + assert trans.n_features_in_ == X.shape[1] + + trans.inverse_transform(X_trans) + assert trans.n_features_in_ == X.shape[1] From ee549e6f7b21810ae7faff5c3753acd64d661f40 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 6 Sep 2021 21:53:58 -0400 Subject: [PATCH 2/3] DOC Adds whats new --- doc/whats_new/v1.0.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index d9dcf8757bc68..ce9da1d4c989d 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -807,6 +807,9 @@ Changelog `n_features_in_` and will be removed in 1.2. :pr:`20240` by :user:`Jérémie du Boisberranger `. +- |Fix| :class:`preprocessing.FunctionTransformer` does not set `n_features_in_` + based on the input to `inverse_transform`. :pr:`20961` by `Thomas Fan`_. + :mod:`sklearn.svm` ................... From 029c124d59c08b230ef7616e98a85c1d3a32c868 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 7 Sep 2021 09:42:01 -0400 Subject: [PATCH 3/3] DOC Adds docstring --- sklearn/preprocessing/_function_transformer.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sklearn/preprocessing/_function_transformer.py b/sklearn/preprocessing/_function_transformer.py index 9ec781196e27b..d975f63e32fe2 100644 --- a/sklearn/preprocessing/_function_transformer.py +++ b/sklearn/preprocessing/_function_transformer.py @@ -71,6 +71,20 @@ class FunctionTransformer(TransformerMixin, BaseEstimator): .. versionadded:: 0.18 + Attributes + ---------- + n_features_in_ : int + Number of features seen during :term:`fit`. Defined only when + `validate=True`. + + .. versionadded:: 0.24 + + feature_names_in_ : ndarray of shape (`n_features_in_`,) + Names of features seen during :term:`fit`. Defined only when `validate=True` + and `X` has feature names that are all strings. + + .. versionadded:: 1.0 + See Also -------- MaxAbsScaler : Scale each feature by its maximum absolute value.