Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit c42edd3

Browse files
thomasjpfan authored and
adrinjalali committed
BUG Fixes FunctionTransformer validation in inverse_transform (#20961)
1 parent c702e54 commit c42edd3

File tree

3 files changed

+48
-6
lines changed

3 files changed

+48
-6
lines changed

doc/whats_new/v1.0.rst

+3
Original file line number | Diff line number | Diff line change
@@ -824,6 +824,9 @@ Changelog
824824
`n_features_in_` and will be removed in 1.2. :pr:`20240` by
825825
:user:`Jérémie du Boisberranger <jeremiedbb>`.
826826

827+
- |Fix| :class:`preprocessing.FunctionTransformer` does not set `n_features_in_`
828+
based on the input to `inverse_transform`. :pr:`20961` by `Thomas Fan`_.
829+
827830
:mod:`sklearn.svm`
828831
...................
829832

sklearn/preprocessing/_function_transformer.py

+21 -6
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import warnings
22

33
from ..base import BaseEstimator, TransformerMixin
4-
from ..utils.validation import _allclose_dense_sparse
4+
from ..utils.validation import _allclose_dense_sparse, check_array
55

66

77
def _identity(X):
@@ -71,6 +71,20 @@ class FunctionTransformer(TransformerMixin, BaseEstimator):
7171
7272
.. versionadded:: 0.18
7373
74+
Attributes
75+
----------
76+
n_features_in_ : int
77+
Number of features seen during :term:`fit`. Defined only when
78+
`validate=True`.
79+
80+
.. versionadded:: 0.24
81+
82+
feature_names_in_ : ndarray of shape (`n_features_in_`,)
83+
Names of features seen during :term:`fit`. Defined only when `validate=True`
84+
and `X` has feature names that are all strings.
85+
86+
.. versionadded:: 1.0
87+
7488
See Also
7589
--------
7690
MaxAbsScaler : Scale each feature by its maximum absolute value.
@@ -110,9 +124,9 @@ def __init__(
110124
self.kw_args = kw_args
111125
self.inv_kw_args = inv_kw_args
112126

113-
def _check_input(self, X):
127+
def _check_input(self, X, *, reset):
114128
if self.validate:
115-
return self._validate_data(X, accept_sparse=self.accept_sparse)
129+
return self._validate_data(X, accept_sparse=self.accept_sparse, reset=reset)
116130
return X
117131

118132
def _check_inverse_transform(self, X):
@@ -146,7 +160,7 @@ def fit(self, X, y=None):
146160
self : object
147161
FunctionTransformer class instance.
148162
"""
149-
X = self._check_input(X)
163+
X = self._check_input(X, reset=True)
150164
if self.check_inverse and not (self.func is None or self.inverse_func is None):
151165
self._check_inverse_transform(X)
152166
return self
@@ -164,6 +178,7 @@ def transform(self, X):
164178
X_out : array-like, shape (n_samples, n_features)
165179
Transformed input.
166180
"""
181+
X = self._check_input(X, reset=False)
167182
return self._transform(X, func=self.func, kw_args=self.kw_args)
168183

169184
def inverse_transform(self, X):
@@ -179,11 +194,11 @@ def inverse_transform(self, X):
179194
X_out : array-like, shape (n_samples, n_features)
180195
Transformed input.
181196
"""
197+
if self.validate:
198+
X = check_array(X, accept_sparse=self.accept_sparse)
182199
return self._transform(X, func=self.inverse_func, kw_args=self.inv_kw_args)
183200

184201
def _transform(self, X, func=None, kw_args=None):
185-
X = self._check_input(X)
186-
187202
if func is None:
188203
func = _identity
189204

sklearn/preprocessing/tests/test_function_transformer.py

+24
Original file line number | Diff line number | Diff line change
@@ -174,3 +174,27 @@ def test_function_transformer_frame():
174174
transformer = FunctionTransformer()
175175
X_df_trans = transformer.fit_transform(X_df)
176176
assert hasattr(X_df_trans, "loc")
177+
178+
179+
def test_function_transformer_validate_inverse():
180+
"""Test that function transformer does not reset estimator in
181+
`inverse_transform`."""
182+
183+
def add_constant_feature(X):
184+
X_one = np.ones((X.shape[0], 1))
185+
return np.concatenate((X, X_one), axis=1)
186+
187+
def inverse_add_constant(X):
188+
return X[:, :-1]
189+
190+
X = np.array([[1, 2], [3, 4], [3, 4]])
191+
trans = FunctionTransformer(
192+
func=add_constant_feature,
193+
inverse_func=inverse_add_constant,
194+
validate=True,
195+
)
196+
X_trans = trans.fit_transform(X)
197+
assert trans.n_features_in_ == X.shape[1]
198+
199+
trans.inverse_transform(X_trans)
200+
assert trans.n_features_in_ == X.shape[1]

0 commit comments

Comments
 (0)