From 1322f1ce32d033fe4e44f442073281564f67cfed Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Wed, 2 Aug 2017 09:21:43 +1000 Subject: [PATCH 1/2] Add missing mixins to ClassifierChain --- sklearn/multioutput.py | 2 +- sklearn/tests/test_multioutput.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index a84a6ce36b218..bce411bb44e94 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -368,7 +368,7 @@ def score(self, X, y): return np.mean(np.all(y == y_pred, axis=1)) -class ClassifierChain(BaseEstimator): +class ClassifierChain(BaseEstimator, ClassifierMixin, MetaEstimatorMixin): """A multi-label model that arranges binary classifiers into a chain. Each model makes a prediction in the order specified by the chain using diff --git a/sklearn/tests/test_multioutput.py b/sklearn/tests/test_multioutput.py index 0c58d04c27581..4149d50a400f3 100644 --- a/sklearn/tests/test_multioutput.py +++ b/sklearn/tests/test_multioutput.py @@ -29,6 +29,7 @@ from sklearn.multioutput import MultiOutputClassifier from sklearn.multioutput import MultiOutputRegressor from sklearn.svm import LinearSVC +from .base import ClassifierMixin from sklearn.utils import shuffle @@ -380,6 +381,8 @@ def test_classifier_chain_fit_and_predict_with_logistic_regression(): assert_equal([c.coef_.size for c in classifier_chain.estimators_], list(range(X.shape[1], X.shape[1] + Y.shape[1]))) + assert isinstance(classifier_chain, ClassifierMixin) + def test_classifier_chain_fit_and_predict_with_linear_svc(): # Fit classifier chain and verify predict performance using LinearSVC From 0377924c4d11230ef06dd715bd61ff6a9b39e53c Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Wed, 2 Aug 2017 10:20:03 +1000 Subject: [PATCH 2/2] Fix import in test --- sklearn/tests/test_multioutput.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_multioutput.py b/sklearn/tests/test_multioutput.py index 4149d50a400f3..5d5de53bbde6c 100644 --- a/sklearn/tests/test_multioutput.py +++ b/sklearn/tests/test_multioutput.py @@ -29,7 +29,7 @@ from sklearn.multioutput import MultiOutputClassifier from sklearn.multioutput import MultiOutputRegressor from sklearn.svm import LinearSVC -from .base import ClassifierMixin +from sklearn.base import ClassifierMixin from sklearn.utils import shuffle