diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index acadaf6cadf5b..da647aba9b1e5 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1376,3 +1376,15 @@ def check_decision_path(name): def test_decision_path(): for name in ALL_TREES: yield (check_decision_path, name) + + +def check_no_sparse_y_support(name): + X, y = X_multilabel, csr_matrix(y_multilabel) + TreeEstimator = ALL_TREES[name] + assert_raises(ValueError, TreeEstimator(random_state=0).fit, X, y) + + +def test_no_sparse_y_support(): + # Currently we don't support sparse y + for name in ALL_TREES: + yield (check_decision_path, name) diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index 028d2626bc23c..363c0220f9f4a 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -151,7 +151,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, random_state = check_random_state(self.random_state) if check_input: X = check_array(X, dtype=DTYPE, accept_sparse="csc") - y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None) + y = check_array(y, ensure_2d=False, dtype=None) if issparse(X): X.sort_indices()