diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index c0b1192eafc32..5173d286967c6 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -32,6 +32,7 @@ from ..utils import Bunch from ..utils import check_array from ..utils import check_random_state +from ..utils.validation import _check_sample_weight from ..utils import compute_sample_weight from ..utils.multiclass import check_classification_targets from ..utils.validation import check_is_fitted @@ -266,18 +267,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, "or larger than 1").format(max_leaf_nodes)) if sample_weight is not None: - if (getattr(sample_weight, "dtype", None) != DOUBLE or - not sample_weight.flags.contiguous): - sample_weight = np.ascontiguousarray( - sample_weight, dtype=DOUBLE) - if len(sample_weight.shape) > 1: - raise ValueError("Sample weights array has more " - "than one dimension: %d" % - len(sample_weight.shape)) - if len(sample_weight) != n_samples: - raise ValueError("Number of weights=%d does not match " - "number of samples=%d" % - (len(sample_weight), n_samples)) + sample_weight = _check_sample_weight(sample_weight, X, DOUBLE) if expanded_class_weight is not None: if sample_weight is not None: diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 4d13f23818e2b..d4ce87e5cfba6 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -3,7 +3,6 @@ """ import copy import pickle -from functools import partial from itertools import product import struct @@ -1121,7 +1120,8 @@ def test_sample_weight_invalid(): clf.fit(X, y, sample_weight=sample_weight) sample_weight = np.array(0) - with pytest.raises(ValueError): + expected_err = r"Singleton.* cannot be considered a valid collection" + with pytest.raises(TypeError, match=expected_err): clf.fit(X, y, sample_weight=sample_weight) sample_weight = np.ones(101)