[MRG+1] ENH: Adds FunctionTransformer #4798
Changes from all commits
190caaf
69f5d66
67ddf96
ec7ddcb
0b4b880
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
""" | ||
========================================================= | ||
Using FunctionTransformer to select columns | ||
========================================================= | ||
|
||
Shows how to use a function transformer in a pipeline. If you know your | ||
dataset's first principal component is irrelevant for a classification task, | ||
you can use the FunctionTransformer to select all but the first column of the | ||
PCA transformed data. | ||
""" | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from sklearn.cross_validation import train_test_split | ||
from sklearn.decomposition import PCA | ||
from sklearn.pipeline import make_pipeline | ||
from sklearn.preprocessing import FunctionTransformer | ||
|
||
|
||
def _generate_vector(shift=0.5, noise=15): | ||
return np.arange(1000) + (np.random.rand(1000) - shift) * noise | ||
|
||
|
||
def generate_dataset(): | ||
""" | ||
This dataset is two lines with a slope ~ 1, where one has | ||
a y offset of ~100 | ||
""" | ||
return np.vstack(( | ||
np.vstack(( | ||
_generate_vector(), | ||
_generate_vector() + 100, | ||
)).T, | ||
np.vstack(( | ||
_generate_vector(), | ||
_generate_vector(), | ||
)).T, | ||
)), np.hstack((np.zeros(1000), np.ones(1000))) | ||
|
||
|
||
def all_but_first_column(X): | ||
return X[:, 1:] | ||
|
||
|
||
def drop_first_component(X, y): | ||
""" | ||
Create a pipeline with PCA and the column selector and use it to | ||
transform the dataset. | ||
""" | ||
pipeline = make_pipeline( | ||
PCA(), FunctionTransformer(all_but_first_column), | ||
) | ||
X_train, X_test, y_train, y_test = train_test_split(X, y) | ||
pipeline.fit(X_train, y_train) | ||
return pipeline.transform(X_test), y_test | ||
|
||
|
||
if __name__ == '__main__': | ||
X, y = generate_dataset() | ||
plt.scatter(X[:, 0], X[:, 1], c=y, s=50) | ||
plt.show() | ||
X_transformed, y_transformed = drop_first_component(*generate_dataset()) | ||
plt.scatter( | ||
X_transformed[:, 0], | ||
np.zeros(len(X_transformed)), | ||
c=y_transformed, | ||
s=50, | ||
) | ||
plt.show() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from ..base import BaseEstimator, TransformerMixin | ||
from ..utils import check_array | ||
|
||
|
||
def _identity(X): | ||
"""The identity function. | ||
""" | ||
return X | ||
|
||
|
||
class FunctionTransformer(BaseEstimator, TransformerMixin): | ||
"""Constructs a transformer from an arbitrary callable. | ||
|
||
A FunctionTransformer forwards its X (and optionally y) arguments to a | ||
user-defined function or function object and returns the result of this | ||
function. This is useful for stateless transformations such as taking the | ||
log of frequencies, doing custom scaling, etc. | ||
|
||
A FunctionTransformer will not do any checks on its function's output. | ||
|
||
Note: If a lambda is used as the function, then the resulting | ||
transformer will not be pickleable. | ||
|
||
Parameters | ||
---------- | ||
func : callable, optional default=None | ||
The callable to use for the transformation. This will be passed | ||
the same arguments as transform, with args and kwargs forwarded. | ||
If func is None, then func will be the identity function. | ||
|
||
validate : bool, optional default=True | ||
Indicate that the input X array should be checked before calling | ||
func. If validate is false, there will be no input validation. | ||
Review comment: Note that this will ensure the input is a non-empty, 2-dimensional array (or sparse matrix) of finite numbers.
|
||
If it is true, then X will be converted to a 2-dimensional NumPy | ||
array or sparse matrix. If this conversion is not possible or X | ||
contains NaN or infinity, an exception is raised. | ||
|
||
accept_sparse : boolean, optional | ||
Indicate that func accepts a sparse matrix as input. If validate is | ||
False, this has no effect. Otherwise, if accept_sparse is false, | ||
sparse matrix inputs will cause an exception to be raised. | ||
|
||
pass_y: bool, optional default=False | ||
Indicate that transform should forward the y argument to the | ||
inner callable. | ||
|
||
""" | ||
def __init__(self, func=None, validate=True, | ||
accept_sparse=False, pass_y=False): | ||
self.func = func | ||
self.validate = validate | ||
self.accept_sparse = accept_sparse | ||
self.pass_y = pass_y | ||
|
||
def fit(self, X, y=None): | ||
if self.validate: | ||
check_array(X, self.accept_sparse) | ||
return self | ||
|
||
def transform(self, X, y=None): | ||
if self.validate: | ||
X = check_array(X, self.accept_sparse) | ||
|
||
return (self.func or _identity)(X, *((y,) if self.pass_y else ())) |
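The single-line return in `transform` combines the identity fallback and the optional forwarding of `y` in one expression. A minimal, dependency-free sketch of the same dispatch, unrolled for readability, might look like the following. The helper name `forward_to_func` and the sample data are hypothetical and are not part of this PR.

```python
def _identity(X):
    """Return the input unchanged."""
    return X


def forward_to_func(func, X, y=None, pass_y=False):
    """Unrolled form of (func or _identity)(X, *((y,) if pass_y else ())).

    Hypothetical helper for illustration only; not part of the PR.
    """
    # Fall back to the identity function when no callable was supplied.
    target = func or _identity
    # y is forwarded as a positional argument only when pass_y is set.
    extra_args = (y,) if pass_y else ()
    return target(X, *extra_args)


rows = [[1, 2], [3, 4]]
labels = [0, 1]

# func=None: the input comes back unchanged.
unchanged = forward_to_func(None, rows)

# pass_y=False (the default): the callable sees X only.
first_column_dropped = forward_to_func(lambda X: [r[1:] for r in X], rows)

# pass_y=True: the callable receives X and y as positional arguments.
paired = forward_to_func(
    lambda X, y: list(zip(X, y)), rows, labels, pass_y=True
)
```

Note that with `pass_y=False` a `y` given to `transform` is silently dropped, which is why the tests in this PR check the recorded positional arguments in both modes.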
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from nose.tools import assert_equal | ||
import numpy as np | ||
|
||
from ..function_transformer import FunctionTransformer | ||
|
||
|
||
def _make_func(args_store, kwargs_store, func=lambda X, *a, **k: X): | ||
def _func(X, *args, **kwargs): | ||
args_store.append(X) | ||
args_store.extend(args) | ||
kwargs_store.update(kwargs) | ||
return func(X) | ||
|
||
return _func | ||
|
||
|
||
def test_delegate_to_func(): | ||
# (args|kwargs)_store will hold the positional and keyword arguments | ||
# passed to the function inside the FunctionTransformer. | ||
args_store = [] | ||
kwargs_store = {} | ||
X = np.arange(10).reshape((5, 2)) | ||
np.testing.assert_array_equal( | ||
FunctionTransformer(_make_func(args_store, kwargs_store)).transform(X), | ||
X, | ||
'transform should have returned X unchanged', | ||
) | ||
|
||
# The function should only have received X. | ||
assert_equal( | ||
args_store, | ||
[X], | ||
'Incorrect positional arguments passed to func: {args}'.format( | ||
args=args_store, | ||
), | ||
) | ||
assert_equal( | ||
kwargs_store, | ||
{}, | ||
'Unexpected keyword arguments passed to func: {args}'.format( | ||
args=kwargs_store, | ||
), | ||
) | ||
|
||
# reset the argument stores. | ||
args_store[:] = [] # python2 compatible inplace list clear. | ||
kwargs_store.clear() | ||
y = object() | ||
|
||
np.testing.assert_array_equal( | ||
FunctionTransformer( | ||
_make_func(args_store, kwargs_store), | ||
pass_y=True, | ||
).transform(X, y), | ||
X, | ||
'transform should have returned X unchanged', | ||
) | ||
|
||
# The function should have received X and y. | ||
assert_equal( | ||
args_store, | ||
[X, y], | ||
'Incorrect positional arguments passed to func: {args}'.format( | ||
args=args_store, | ||
), | ||
) | ||
assert_equal( | ||
kwargs_store, | ||
{}, | ||
'Unexpected keyword arguments passed to func: {args}'.format( | ||
args=kwargs_store, | ||
), | ||
) | ||
|
||
|
||
def test_np_log(): | ||
X = np.arange(10).reshape((5, 2)) | ||
|
||
# Test that the numpy.log example still works. | ||
np.testing.assert_array_equal( | ||
FunctionTransformer(np.log).transform(X), | ||
np.log(X), | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -131,7 +131,8 @@ def _yield_transformer_checks(name, Transformer): | |
'PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD']: | ||
yield check_transformer_data_not_an_array | ||
# these don't actually fit the data, so don't raise errors | ||
if name not in ['AdditiveChi2Sampler', 'Binarizer', 'Normalizer']: | ||
if name not in ['AdditiveChi2Sampler', 'Binarizer', | ||
Review comment: Very silly comment, but these seemed alphabetically ordered, so it can't hurt to keep the order.
Review comment: Actually I think it can be removed here now, and probably AdditiveChi2Sampler can be removed, too, iirc.
Review comment: Removing these causes errors in sklearn.tests.test_common.test_non_meta_estimators.
Review comment: What check fails?
Review comment: Ah ok, then never mind, leave it as is.
|
||
'FunctionTransformer', 'Normalizer']: | ||
# basic tests | ||
yield check_transformer_general | ||
yield check_transformers_unfitted | ||
|
Review comment: Reading this I might be confused. "Why will I want this? I can just do X = np.log(X) or something".
I think the most valuable use case is to grid search over the transformer function used. The paragraph might read better with a mention of this scenario.

Review comment: Also, most people don't know what a callable is.

Review comment: @mblondel would you still stick with the name then?
I agree with @vene's comment.

Review comment: I don't agree that it's most valuable with a search over the function. But I think it has no value except in a Pipeline or FeatureUnion and it should be illustrated in that context.

Review comment: Whoops that is what I meant. It is useful for pipelines. I don't think searching over the function is the main use case. I think selecting columns / slicing etc is a good use case.
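
For reference, the grid-search scenario discussed in this thread could be sketched roughly as follows. This is an illustration under stated assumptions, not code from the PR: it uses the API of a recent scikit-learn release (`sklearn.model_selection`, which differs from the `sklearn.cross_validation` module imported in this diff), a synthetic dataset, and `Ridge` as an arbitrary downstream estimator.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer

# Synthetic data (an assumption for this sketch): the target is built
# from the log of strictly positive features, plus a little noise.
rng = np.random.RandomState(0)
X = rng.uniform(1.0, 100.0, size=(200, 3))
y = np.log(X).sum(axis=1) + rng.normal(scale=0.01, size=200)

pipeline = make_pipeline(FunctionTransformer(), Ridge())

# make_pipeline names each step after its lowercased class name, so the
# transformer's `func` parameter is addressed as below. func=None means
# the identity function.
param_grid = {"functiontransformer__func": [None, np.log, np.sqrt]}

search = GridSearchCV(pipeline, param_grid, cv=3)
search.fit(X, y)

# Because y was constructed from log(X), np.log is expected to score best.
best_func = search.best_params_["functiontransformer__func"]
```

Using named functions such as `np.log` rather than lambdas keeps the pipeline pickleable, in line with the note in the class docstring.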