diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 9e9c2bba016e6..d8fe5f494a64e 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -7,6 +7,7 @@ # Virgile Fritsch # Alexandre Gramfort # Lars Buitinck +# Tim Head # Licence: BSD from collections import defaultdict @@ -15,14 +16,16 @@ import numpy as np from scipy import sparse -from .base import BaseEstimator, TransformerMixin +from .base import BaseEstimator, ClassifierMixin +from .base import MetaEstimatorMixin, TransformerMixin from .externals.joblib import Parallel, delayed from .externals import six from .utils import tosequence from .utils.metaestimators import if_delegate_has_method from .externals.six import iteritems -__all__ = ['Pipeline', 'FeatureUnion'] +__all__ = ['Pipeline', 'FeatureUnion', 'PredictionTransformer', + 'ThresholdClassifier'] class Pipeline(BaseEstimator): @@ -576,3 +579,32 @@ def make_union(*transformers): f : FeatureUnion """ return FeatureUnion(_name_estimators(transformers)) + + +class PredictionTransformer(BaseEstimator, TransformerMixin, MetaEstimatorMixin): + def __init__(self, clf): + """Replaces all features with `clf.predict_proba(X)`""" + self.clf = clf + + def fit(self, X, y): + self.clf.fit(X, y) + return self + + def transform(self, X): + return self.clf.predict_proba(X) + + +class ThresholdClassifier(BaseEstimator, ClassifierMixin): + def __init__(self, threshold=0.5): + """Classify samples based on whether they are above of below `threshold`""" + self.threshold = threshold + + def fit(self, X, y): + self.classes_ = np.unique(y) + return self + + def predict(self, X): + # the implementation used here breaks ties differently + # from the one used in RFs: + #return self.classes_.take(np.argmax(X, axis=1), axis=0) + return np.where(X[:, 0]>self.threshold, *self.classes_) diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index a8c5fff4efe8f..d1574697ca282 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -17,13 +17,15 @@ from sklearn.base import clone from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union +from sklearn.pipeline import PredictionTransformer, ThresholdClassifier from sklearn.svm import SVC from sklearn.linear_model import LogisticRegression from sklearn.linear_model import LinearRegression from sklearn.cluster import KMeans from sklearn.feature_selection import SelectKBest, f_classif from sklearn.decomposition import PCA, TruncatedSVD -from sklearn.datasets import load_iris +from sklearn.ensemble import RandomForestClassifier +from sklearn.datasets import load_iris, make_classification from sklearn.preprocessing import StandardScaler from sklearn.feature_extraction.text import CountVectorizer @@ -472,3 +474,16 @@ def test_X1d_inverse_transform(): X = np.ones(10) msg = "1d X will not be reshaped in pipeline.inverse_transform" assert_warns_message(FutureWarning, msg, pipeline.inverse_transform, X) + + +def test_prediction_transformer_pipeline(): + X, y = make_classification() + + pipe = make_pipeline(PredictionTransformer(RandomForestClassifier()), + ThresholdClassifier()) + pipe.fit(X, y) + + clf = RandomForestClassifier() + clf.fit(X, y) + + assert_array_equal(clf.predict(X), pipe.predict(X))