Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1d487fb

Browse files
facaiyMechCoder
authored andcommitted
[MRG+1] issue #6532 Add inverse_transform function (#6570)
* [MRG+1] #6532 Add inverse_func argument to FunctionTransformer * modify test:inverse_func is not true inverse
1 parent 45ff64b commit 1d487fb

2 files changed

Lines changed: 35 additions & 5 deletions

File tree

sklearn/preprocessing/_function_transformer.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ class FunctionTransformer(BaseEstimator, TransformerMixin):
3030
the same arguments as transform, with args and kwargs forwarded.
3131
If func is None, then func will be the identity function.
3232
33+
inverse_func : callable, optional default=None
34+
The callable to use for the inverse transformation. This will be
35+
passed the same arguments as inverse transform, with args and
36+
kwargs forwarded. If inverse_func is None, then inverse_func
37+
will be the identity function.
38+
3339
validate : bool, optional default=True
3440
Indicate that the input X array should be checked before calling
3541
func. If validate is false, there will be no input validation.
@@ -49,26 +55,38 @@ class FunctionTransformer(BaseEstimator, TransformerMixin):
4955
kw_args : dict, optional
5056
Dictionary of additional keyword arguments to pass to func.
5157
58+
inv_kw_args : dict, optional
59+
Dictionary of additional keyword arguments to pass to inverse_func.
60+
5261
"""
53-
def __init__(self, func=None, validate=True,
62+
def __init__(self, func=None, inverse_func=None, validate=True,
5463
accept_sparse=False, pass_y=False,
55-
kw_args=None):
64+
kw_args=None, inv_kw_args=None):
5665
self.func = func
66+
self.inverse_func = inverse_func
5767
self.validate = validate
5868
self.accept_sparse = accept_sparse
5969
self.pass_y = pass_y
6070
self.kw_args = kw_args
71+
self.inv_kw_args = inv_kw_args
6172

6273
def fit(self, X, y=None):
6374
if self.validate:
6475
check_array(X, self.accept_sparse)
6576
return self
6677

6778
def transform(self, X, y=None):
79+
return self._transform(X, y, self.func, self.kw_args)
80+
81+
def inverse_transform(self, X, y=None):
82+
return self._transform(X, y, self.inverse_func, self.inv_kw_args)
83+
84+
def _transform(self, X, y=None, func=None, kw_args=None):
6885
if self.validate:
6986
X = check_array(X, self.accept_sparse)
70-
func = self.func if self.func is not None else _identity
7187

88+
if func is None:
89+
func = _identity
7290

7391
return func(X, *((y,) if self.pass_y else ()),
74-
**(self.kw_args if self.kw_args else {}))
92+
**(kw_args if kw_args else {}))

sklearn/preprocessing/tests/test_function_transformer.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,16 @@ def test_kw_arg_reset():
115115

116116
# Test that rounding is correct
117117
testing.assert_array_equal(F.transform(X),
118-
np.around(X, decimals=1))
118+
np.around(X, decimals=1))
119+
120+
121+
def test_inverse_transform():
122+
X = np.array([1, 4, 9, 16]).reshape((2, 2))
123+
124+
# Test that inverse_transform works correctly
125+
F = FunctionTransformer(
126+
func=np.sqrt,
127+
inverse_func=np.around, inv_kw_args=dict(decimals=3))
128+
testing.assert_array_equal(
129+
F.inverse_transform(F.transform(X)),
130+
np.around(np.sqrt(X), decimals=3))

0 commit comments

Comments
 (0)