@@ -30,6 +30,12 @@ class FunctionTransformer(BaseEstimator, TransformerMixin):
3030 the same arguments as transform, with args and kwargs forwarded.
3131 If func is None, then func will be the identity function.
3232
33+ inverse_func : callable, optional default=None
34+ The callable to use for the inverse transformation. This will be
35+ passed the same arguments as inverse transform, with args and
36+ kwargs forwarded. If inverse_func is None, then inverse_func
37+ will be the identity function.
38+
3339 validate : bool, optional default=True
3440 Indicate that the input X array should be checked before calling
3541 func. If validate is false, there will be no input validation.
@@ -49,26 +55,38 @@ class FunctionTransformer(BaseEstimator, TransformerMixin):
4955 kw_args : dict, optional
5056 Dictionary of additional keyword arguments to pass to func.
5157
58+ inv_kw_args : dict, optional
59+ Dictionary of additional keyword arguments to pass to inverse_func.
60+
5261 """
53- def __init__ (self , func = None , validate = True ,
62+ def __init__ (self , func = None , inverse_func = None , validate = True ,
5463 accept_sparse = False , pass_y = False ,
55- kw_args = None ):
64+ kw_args = None , inv_kw_args = None ):
5665 self .func = func
66+ self .inverse_func = inverse_func
5767 self .validate = validate
5868 self .accept_sparse = accept_sparse
5969 self .pass_y = pass_y
6070 self .kw_args = kw_args
71+ self .inv_kw_args = inv_kw_args
6172
6273 def fit (self , X , y = None ):
6374 if self .validate :
6475 check_array (X , self .accept_sparse )
6576 return self
6677
6778 def transform (self , X , y = None ):
79+ return self ._transform (X , y , self .func , self .kw_args )
80+
81+ def inverse_transform (self , X , y = None ):
82+ return self ._transform (X , y , self .inverse_func , self .inv_kw_args )
83+
84+ def _transform (self , X , y = None , func = None , kw_args = None ):
6885 if self .validate :
6986 X = check_array (X , self .accept_sparse )
70- func = self .func if self .func is not None else _identity
7187
88+ if func is None :
89+ func = _identity
7290
7391 return func (X , * ((y ,) if self .pass_y else ()),
74- ** (self . kw_args if self . kw_args else {}))
92+ ** (kw_args if kw_args else {}))
0 commit comments