diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index a84a6ce36b218..bce411bb44e94 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -368,7 +368,7 @@ def score(self, X, y): return np.mean(np.all(y == y_pred, axis=1)) -class ClassifierChain(BaseEstimator): +class ClassifierChain(BaseEstimator, ClassifierMixin, MetaEstimatorMixin): """A multi-label model that arranges binary classifiers into a chain. Each model makes a prediction in the order specified by the chain using diff --git a/sklearn/tests/test_multioutput.py b/sklearn/tests/test_multioutput.py index 0c58d04c27581..5d5de53bbde6c 100644 --- a/sklearn/tests/test_multioutput.py +++ b/sklearn/tests/test_multioutput.py @@ -29,6 +29,7 @@ from sklearn.multioutput import MultiOutputClassifier from sklearn.multioutput import MultiOutputRegressor from sklearn.svm import LinearSVC +from sklearn.base import ClassifierMixin from sklearn.utils import shuffle @@ -380,6 +381,8 @@ def test_classifier_chain_fit_and_predict_with_logistic_regression(): assert_equal([c.coef_.size for c in classifier_chain.estimators_], list(range(X.shape[1], X.shape[1] + Y.shape[1]))) + assert isinstance(classifier_chain, ClassifierMixin) + def test_classifier_chain_fit_and_predict_with_linear_svc(): # Fit classifier chain and verify predict performance using LinearSVC