Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ae5a22f

Browse files
thomasjpfanjeremiedbb
authored andcommitted
FIX Handles all numerical DataFrames with check_inverse=True in FunctionTransformer (scikit-learn#25274)
Co-authored-by: Jérémie du Boisberranger <[email protected]>
1 parent f58c08a commit ae5a22f

File tree

3 files changed

+44
-1
lines changed

3 files changed

+44
-1
lines changed

doc/whats_new/v1.2.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ Changelog
2727
`verbose` parameter set to a value greater than 0.
2828
:pr:`25250` by :user:`Jérémie Du Boisberranger <jeremiedbb>`.
2929

30+
:mod:`sklearn.preprocessing`
31+
............................
32+
33+
- |Fix| :meth:`preprocessing.FunctionTransformer.inverse_transform` correctly
34+
supports DataFrames that are all numerical when `check_inverse=True`.
35+
:pr:`25274` by `Thomas Fan`_.
36+
3037
:mod:`sklearn.utils`
3138
....................
3239

sklearn/preprocessing/_function_transformer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,13 @@ def _check_inverse_transform(self, X):
174174
idx_selected = slice(None, None, max(1, X.shape[0] // 100))
175175
X_round_trip = self.inverse_transform(self.transform(X[idx_selected]))
176176

177-
if not np.issubdtype(X.dtype, np.number):
177+
if hasattr(X, "dtype"):
178+
dtypes = [X.dtype]
179+
elif hasattr(X, "dtypes"):
180+
# Dataframes can have multiple dtypes
181+
dtypes = X.dtypes
182+
183+
if not all(np.issubdtype(d, np.number) for d in dtypes):
178184
raise ValueError(
179185
"'check_inverse' is only supported when all the elements in `X` is"
180186
" numerical."

sklearn/preprocessing/tests/test_function_transformer.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,36 @@ def inverse_func(X):
217217
transformer.fit(data)
218218

219219

220+
def test_function_transformer_support_all_nummerical_dataframes_check_inverse_True():
221+
"""Check support for dataframes with only numerical values."""
222+
pd = pytest.importorskip("pandas")
223+
224+
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
225+
transformer = FunctionTransformer(
226+
func=lambda x: x + 2, inverse_func=lambda x: x - 2, check_inverse=True
227+
)
228+
229+
# Does not raise an error
230+
df_out = transformer.fit_transform(df)
231+
assert_allclose_dense_sparse(df_out, df + 2)
232+
233+
234+
def test_function_transformer_with_dataframe_and_check_inverse_True():
235+
"""Check error is raised when check_inverse=True.
236+
237+
Non-regresion test for gh-25261.
238+
"""
239+
pd = pytest.importorskip("pandas")
240+
transformer = FunctionTransformer(
241+
func=lambda x: x, inverse_func=lambda x: x, check_inverse=True
242+
)
243+
244+
df_mixed = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
245+
msg = "'check_inverse' is only supported when all the elements in `X` is numerical."
246+
with pytest.raises(ValueError, match=msg):
247+
transformer.fit(df_mixed)
248+
249+
220250
@pytest.mark.parametrize(
221251
"X, feature_names_out, input_features, expected",
222252
[

0 commit comments

Comments
 (0)