Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

MAINT use _validate_params in DecisionTreeClassifier and DecisionTreeRegressor #23499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 0 additions & 68 deletions sklearn/ensemble/tests/test_gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,74 +139,6 @@ def test_classification_toy(loss):
TypeError,
"tol must be an instance of float,",
),
# The following parameters are checked in BaseDecisionTree
({"min_samples_leaf": 0}, ValueError, "min_samples_leaf == 0, must be >= 1"),
({"min_samples_leaf": 0.0}, ValueError, "min_samples_leaf == 0.0, must be > 0"),
(
{"min_samples_leaf": "foo"},
TypeError,
"min_samples_leaf must be an instance of float",
),
({"min_samples_split": 1}, ValueError, "min_samples_split == 1, must be >= 2"),
(
{"min_samples_split": 0.0},
ValueError,
"min_samples_split == 0.0, must be > 0.0",
),
(
{"min_samples_split": 1.1},
ValueError,
"min_samples_split == 1.1, must be <= 1.0",
),
(
{"min_samples_split": "foo"},
TypeError,
"min_samples_split must be an instance of float",
),
(
{"min_weight_fraction_leaf": -1},
ValueError,
"min_weight_fraction_leaf == -1, must be >= 0.0",
),
(
{"min_weight_fraction_leaf": 0.6},
ValueError,
"min_weight_fraction_leaf == 0.6, must be <= 0.5",
),
(
{"min_weight_fraction_leaf": "foo"},
TypeError,
"min_weight_fraction_leaf must be an instance of float",
),
({"max_leaf_nodes": 0}, ValueError, "max_leaf_nodes == 0, must be >= 2"),
(
{"max_leaf_nodes": 1.5},
TypeError,
"max_leaf_nodes must be an instance of int",
),
({"max_depth": -1}, ValueError, "max_depth == -1, must be >= 1"),
(
{"max_depth": 1.1},
TypeError,
"max_depth must be an instance of int",
),
(
{"min_impurity_decrease": -1},
ValueError,
"min_impurity_decrease == -1, must be >= 0.0",
),
(
{"min_impurity_decrease": "foo"},
TypeError,
"min_impurity_decrease must be an instance of float",
),
({"ccp_alpha": -1.0}, ValueError, "ccp_alpha == -1.0, must be >= 0.0"),
(
{"ccp_alpha": "foo"},
TypeError,
"ccp_alpha must be an instance of float",
),
({"criterion": "mae"}, ValueError, "criterion='mae' is not supported."),
],
# Avoid long error messages in test names:
# https://github.com/scikit-learn/scikit-learn/issues/21362
Expand Down
4 changes: 0 additions & 4 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
"ComplementNB",
"CountVectorizer",
"DBSCAN",
"DecisionTreeClassifier",
"DecisionTreeRegressor",
"DictVectorizer",
"DictionaryLearning",
"DummyClassifier",
Expand All @@ -477,8 +475,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
"ElasticNetCV",
"EllipticEnvelope",
"EmpiricalCovariance",
"ExtraTreeClassifier",
"ExtraTreeRegressor",
"ExtraTreesClassifier",
"ExtraTreesRegressor",
"FactorAnalysis",
Expand Down
138 changes: 51 additions & 87 deletions sklearn/tree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from abc import ABCMeta
from abc import abstractmethod
from math import ceil
from numbers import Integral, Real

import numpy as np
from scipy.sparse import issparse
Expand All @@ -32,12 +33,12 @@
from ..base import MultiOutputMixin
from ..utils import Bunch
from ..utils import check_random_state
from ..utils import check_scalar
from ..utils.deprecation import deprecated
from ..utils.validation import _check_sample_weight
from ..utils import compute_sample_weight
from ..utils.multiclass import check_classification_targets
from ..utils.validation import check_is_fitted
from ..utils._param_validation import Hidden, Interval, StrOptions

from ._criterion import Criterion
from ._splitter import Splitter
Expand Down Expand Up @@ -97,6 +98,30 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta):
Use derived classes instead.
"""

_parameter_constraints = {
"splitter": [StrOptions({"best", "random"})],
"max_depth": [Interval(Integral, 1, None, closed="left"), None],
"min_samples_split": [
Interval(Integral, 2, None, closed="left"),
Interval(Real, 0.0, 1.0, closed="right"),
],
"min_samples_leaf": [
Interval(Integral, 1, None, closed="left"),
Interval(Real, 0.0, 1.0, closed="neither"),
],
"min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")],
"max_features": [
Interval(Integral, 1, None, closed="left"),
Interval(Real, 0.0, 1.0, closed="right"),
StrOptions({"auto", "sqrt", "log2"}, deprecated={"auto"}),
None,
],
"random_state": ["random_state"],
"max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None],
"min_impurity_decrease": [Interval(Real, 0.0, None, closed="left")],
"ccp_alpha": [Interval(Real, 0.0, None, closed="left")],
}

@abstractmethod
def __init__(
self,
Expand Down Expand Up @@ -153,16 +178,9 @@ def get_n_leaves(self):
return self.tree_.n_leaves

def fit(self, X, y, sample_weight=None, check_input=True):

self._validate_params()
random_state = check_random_state(self.random_state)

check_scalar(
self.ccp_alpha,
name="ccp_alpha",
target_type=numbers.Real,
min_val=0.0,
)

if check_input:
# Need to validate separately here.
# We can't pass multi_output=True because that would allow y to be
Expand Down Expand Up @@ -233,64 +251,21 @@ def fit(self, X, y, sample_weight=None, check_input=True):
if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=DOUBLE)

# Check parameters
if self.max_depth is not None:
check_scalar(
self.max_depth,
name="max_depth",
target_type=numbers.Integral,
min_val=1,
)
max_depth = np.iinfo(np.int32).max if self.max_depth is None else self.max_depth

if isinstance(self.min_samples_leaf, numbers.Integral):
check_scalar(
self.min_samples_leaf,
name="min_samples_leaf",
target_type=numbers.Integral,
min_val=1,
)
min_samples_leaf = self.min_samples_leaf
else: # float
check_scalar(
self.min_samples_leaf,
name="min_samples_leaf",
target_type=numbers.Real,
min_val=0.0,
include_boundaries="neither",
)
min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples))

if isinstance(self.min_samples_split, numbers.Integral):
check_scalar(
self.min_samples_split,
name="min_samples_split",
target_type=numbers.Integral,
min_val=2,
)
min_samples_split = self.min_samples_split
else: # float
check_scalar(
self.min_samples_split,
name="min_samples_split",
target_type=numbers.Real,
min_val=0.0,
max_val=1.0,
include_boundaries="right",
)
min_samples_split = int(ceil(self.min_samples_split * n_samples))
min_samples_split = max(2, min_samples_split)

min_samples_split = max(min_samples_split, 2 * min_samples_leaf)

check_scalar(
self.min_weight_fraction_leaf,
name="min_weight_fraction_leaf",
target_type=numbers.Real,
min_val=0.0,
max_val=0.5,
)

if isinstance(self.max_features, str):
if self.max_features == "auto":
if is_classification:
Expand All @@ -313,55 +288,20 @@ def fit(self, X, y, sample_weight=None, check_input=True):
max_features = max(1, int(np.sqrt(self.n_features_in_)))
elif self.max_features == "log2":
max_features = max(1, int(np.log2(self.n_features_in_)))
else:
raise ValueError(
"Invalid value for max_features. "
"Allowed string values are 'auto', "
"'sqrt' or 'log2'."
)
elif self.max_features is None:
max_features = self.n_features_in_
elif isinstance(self.max_features, numbers.Integral):
check_scalar(
self.max_features,
name="max_features",
target_type=numbers.Integral,
min_val=1,
include_boundaries="left",
)
max_features = self.max_features
else: # float
check_scalar(
self.max_features,
name="max_features",
target_type=numbers.Real,
min_val=0.0,
max_val=1.0,
include_boundaries="right",
)
if self.max_features > 0.0:
max_features = max(1, int(self.max_features * self.n_features_in_))
else:
max_features = 0

self.max_features_ = max_features

if self.max_leaf_nodes is not None:
check_scalar(
self.max_leaf_nodes,
name="max_leaf_nodes",
target_type=numbers.Integral,
min_val=2,
)
max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes

check_scalar(
self.min_impurity_decrease,
name="min_impurity_decrease",
target_type=numbers.Real,
min_val=0.0,
)

if len(y) != n_samples:
raise ValueError(
"Number of labels=%d does not match number of samples=%d"
Expand Down Expand Up @@ -905,6 +845,12 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree):
0.93..., 0.93..., 1. , 0.93..., 1. ])
"""

_parameter_constraints = {
**BaseDecisionTree._parameter_constraints,
"criterion": [StrOptions({"gini", "entropy", "log_loss"}), Hidden(Criterion)],
"class_weight": [dict, list, StrOptions({"balanced"}), None],
}

def __init__(
self,
*,
Expand Down Expand Up @@ -1281,6 +1227,24 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
0.16..., 0.11..., -0.73..., -0.30..., -0.00...])
"""

_parameter_constraints = {
**BaseDecisionTree._parameter_constraints,
"criterion": [
StrOptions(
{
"squared_error",
"friedman_mse",
"absolute_error",
"poisson",
"mse",
"mae",
},
deprecated={"mse", "mae"},
),
Hidden(Criterion),
],
}

def __init__(
self,
*,
Expand Down
Loading