

MAINT Use check_scalar in BaseDecisionTree #21990


Merged: 43 commits, Jan 24, 2022

Changes from all commits (43 commits)
47997ee
max_depth: add tests
genvalen Dec 15, 2021
02fc4af
max_depth: add validation
genvalen Dec 15, 2021
0ab5535
min_samples_split: add tests and validation
genvalen Dec 15, 2021
88b13d6
min_impurity_decrease: add tests and validation
genvalen Dec 15, 2021
645fc22
fix lint issue
genvalen Dec 15, 2021
69c4347
max_leaf_nodes: add tests and validation
genvalen Dec 16, 2021
a2c8c4b
min_weight_fraction_leaf: add tests and validation
genvalen Dec 17, 2021
4100bbd
format files w/ black
genvalen Dec 17, 2021
ec647de
remove comments
genvalen Dec 17, 2021
9bc6582
min_samples_leaf: add tests and validation
genvalen Dec 28, 2021
c850f6c
max_features: add tests and validation
genvalen Dec 28, 2021
4fe47d0
min_samples_split: update
genvalen Dec 28, 2021
44c5afc
max_leaf_nodes: remove redundant tests
genvalen Dec 28, 2021
12eb88a
remove redundant tests
genvalen Dec 28, 2021
664aa88
Merge branch 'BaseDecisionTree_add_check_scalar' of https://github.co…
genvalen Dec 28, 2021
60a215f
Fix lint issue
genvalen Dec 28, 2021
e845597
update range for min_weight_fraction_leaf
genvalen Dec 28, 2021
9bdb404
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
genvalen Dec 28, 2021
6eb026c
remove redundant tests
genvalen Dec 30, 2021
b9a5091
min_samples_split: update to include upper bound
genvalen Dec 30, 2021
eb8bacb
slight style edit to tests
genvalen Dec 30, 2021
1060fd7
update boundary inclusivity for max_features
genvalen Dec 30, 2021
f5973f8
update order of tests
genvalen Dec 30, 2021
5c5769b
update order of validations to match signature more closely
genvalen Dec 30, 2021
2e4528d
remove some boundary args to rely on defaults more and reduce lines
genvalen Dec 30, 2021
dfb4b87
ccp_alpha: add tests and validation
genvalen Jan 4, 2022
fa7a486
edit comment
genvalen Jan 4, 2022
4267678
Update sklearn/tree/_classes.py
genvalen Jan 4, 2022
9d1bed6
min_samples_split: update tests
genvalen Jan 4, 2022
63a5a93
ccp_alpha: remove redundant tests
genvalen Jan 4, 2022
5cf4fa3
update check scalar calls to explicitly reference "name" param
genvalen Jan 8, 2022
0b89587
Update sklearn/tree/_classes.py
genvalen Jan 11, 2022
a61e73c
update tests for min_samples_split
genvalen Jan 11, 2022
51e0980
update messages in test_gbdt_parameter_checks to pass CI
genvalen Jan 12, 2022
5eda80a
Update sklearn/ensemble/tests/test_gradient_boosting.py
genvalen Jan 12, 2022
a65e9bf
Put test_max_features back on line 504
genvalen Jan 12, 2022
5e5b381
fix lint issue
genvalen Jan 12, 2022
3a424c2
Update sklearn/tree/_classes.py
genvalen Jan 13, 2022
521a35c
Revert "update tests for min_samples_split"
genvalen Jan 13, 2022
ce3f22d
update min_val for min_samples_split
genvalen Jan 13, 2022
9357c87
Update sklearn/tree/_classes.py
genvalen Jan 24, 2022
733a3c8
Update sklearn/tree/_classes.py
genvalen Jan 24, 2022
9167c47
update tests
genvalen Jan 24, 2022
87 changes: 60 additions & 27 deletions sklearn/ensemble/tests/test_gradient_boosting.py
@@ -79,31 +79,63 @@ def test_classification_toy(loss):


@pytest.mark.parametrize(
"params, err_msg",
"params, err_type, err_msg",
[
({"n_estimators": 0}, "n_estimators must be greater than 0"),
({"n_estimators": -1}, "n_estimators must be greater than 0"),
({"learning_rate": 0}, "learning_rate must be greater than 0"),
({"learning_rate": -1.0}, "learning_rate must be greater than 0"),
({"loss": "foobar"}, "Loss 'foobar' not supported"),
({"min_samples_split": 0.0}, "min_samples_split must be an integer"),
({"min_samples_split": -1.0}, "min_samples_split must be an integer"),
({"min_samples_split": 1.1}, "min_samples_split must be an integer"),
({"min_samples_leaf": 0}, "min_samples_leaf must be at least 1 or"),
({"min_samples_leaf": -1.0}, "min_samples_leaf must be at least 1 or"),
({"min_weight_fraction_leaf": -1.0}, "min_weight_fraction_leaf must in"),
({"min_weight_fraction_leaf": 0.6}, "min_weight_fraction_leaf must in"),
({"subsample": 0.0}, r"subsample must be in \(0,1\]"),
({"subsample": 1.1}, r"subsample must be in \(0,1\]"),
({"subsample": -0.1}, r"subsample must be in \(0,1\]"),
({"max_depth": -0.1}, "max_depth must be greater than zero"),
({"max_depth": 0}, "max_depth must be greater than zero"),
({"init": {}}, "The init parameter must be an estimator or 'zero'"),
({"max_features": "invalid"}, "Invalid value for max_features:"),
({"max_features": 0}, r"max_features must be in \(0, n_features\]"),
({"max_features": 100}, r"max_features must be in \(0, n_features\]"),
({"max_features": -0.1}, r"max_features must be in \(0, n_features\]"),
({"n_iter_no_change": "invalid"}, "n_iter_no_change should either be"),
({"n_estimators": 0}, ValueError, "n_estimators must be greater than 0"),
({"n_estimators": -1}, ValueError, "n_estimators must be greater than 0"),
({"learning_rate": 0}, ValueError, "learning_rate must be greater than 0"),
({"learning_rate": -1.0}, ValueError, "learning_rate must be greater than 0"),
({"loss": "foobar"}, ValueError, "Loss 'foobar' not supported"),
(
{"min_samples_split": 0.0},
ValueError,
"min_samples_split == 0.0, must be > 0.0",
),
(
{"min_samples_split": -1.0},
ValueError,
"min_samples_split == -1.0, must be > 0.0",
),
(
{"min_samples_split": 1.1},
ValueError,
"min_samples_split == 1.1, must be <= 1.0.",
),
({"min_samples_leaf": 0}, ValueError, "min_samples_leaf == 0, must be >= 1"),
(
{"min_samples_leaf": -1.0},
ValueError,
"min_samples_leaf == -1.0, must be > 0.0.",
),
(
{"min_weight_fraction_leaf": -1.0},
ValueError,
"min_weight_fraction_leaf == -1.0, must be >= 0",
),
(
{"min_weight_fraction_leaf": 0.6},
ValueError,
"min_weight_fraction_leaf == 0.6, must be <= 0.5.",
),
({"subsample": 0.0}, ValueError, r"subsample must be in \(0,1\]"),
({"subsample": 1.1}, ValueError, r"subsample must be in \(0,1\]"),
({"subsample": -0.1}, ValueError, r"subsample must be in \(0,1\]"),
({"max_depth": -0.1}, TypeError, "max_depth must be an instance of"),
({"max_depth": 0}, ValueError, "max_depth == 0, must be >= 1."),
({"init": {}}, ValueError, "The init parameter must be an estimator or 'zero'"),
({"max_features": "invalid"}, ValueError, "Invalid value for max_features:"),
({"max_features": 0}, ValueError, "max_features == 0, must be >= 1"),
({"max_features": 100}, ValueError, "max_features == 100, must be <="),
(
{"max_features": -0.1},
ValueError,
r"max_features must be in \(0, n_features\]",
),
(
{"n_iter_no_change": "invalid"},
ValueError,
"n_iter_no_change should either be",
),
],
# Avoid long error messages in test names:
# https://github.com/scikit-learn/scikit-learn/issues/21362
@@ -116,10 +148,11 @@ def test_classification_toy(loss):
(GradientBoostingClassifier, iris.data, iris.target),
],
)
def test_gbdt_parameter_checks(GradientBoosting, X, y, params, err_msg):
def test_gbdt_parameter_checks(GradientBoosting, X, y, params, err_type, err_msg):
# Check input parameter validation for GradientBoosting
with pytest.raises(ValueError, match=err_msg):
GradientBoosting(**params).fit(X, y)
est = GradientBoosting(**params)
with pytest.raises(err_type, match=err_msg):
est.fit(X, y)


@pytest.mark.parametrize(
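For reference, the new `err_type` column distinguishes the two failure modes of `check_scalar`: a `TypeError` when the value has the wrong type, and a `ValueError` when a well-typed value is out of range. A minimal standalone sketch of both, matching the message formats the updated tests expect:

```python
from numbers import Integral

from sklearn.utils import check_scalar

# Wrong type: a float where an integer is required raises TypeError,
# e.g. "max_depth must be an instance of ..., not float."
try:
    check_scalar(-0.1, name="max_depth", target_type=Integral, min_val=1)
except TypeError as exc:
    print(exc)

# Right type but out of range raises ValueError,
# e.g. "max_depth == 0, must be >= 1."
try:
    check_scalar(0, name="max_depth", target_type=Integral, min_val=1)
except ValueError as exc:
    print(exc)
```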
129 changes: 80 additions & 49 deletions sklearn/tree/_classes.py
@@ -32,6 +32,7 @@
from ..base import MultiOutputMixin
from ..utils import Bunch
from ..utils import check_random_state
from ..utils import check_scalar
from ..utils.deprecation import deprecated
from ..utils.validation import _check_sample_weight
from ..utils import compute_sample_weight
@@ -151,8 +152,12 @@ def fit(self, X, y, sample_weight=None, check_input=True):

random_state = check_random_state(self.random_state)

if self.ccp_alpha < 0.0:
raise ValueError("ccp_alpha must be greater than or equal to 0")
check_scalar(
self.ccp_alpha,
name="ccp_alpha",
target_type=numbers.Real,
min_val=0.0,
)

if check_input:
# Need to validate separately here.
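This hunk replaces the hand-rolled non-negativity check for `ccp_alpha` with the shared helper; boundaries are inclusive by default, so `ccp_alpha=0.0` (no pruning) stays accepted while negative values fail with the standardized message. A quick standalone sketch of the behaviour:

```python
import numbers

from sklearn.utils import check_scalar

for value in (0.0, 0.05, -1e-3):
    try:
        check_scalar(value, name="ccp_alpha", target_type=numbers.Real, min_val=0.0)
        print(f"ccp_alpha={value}: accepted")
    except ValueError as exc:
        # e.g. "ccp_alpha == -0.001, must be >= 0.0."
        print(f"ccp_alpha={value}: {exc}")
```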
@@ -225,46 +230,63 @@ def fit(self, X, y, sample_weight=None, check_input=True):
y = np.ascontiguousarray(y, dtype=DOUBLE)

# Check parameters
if self.max_depth is not None:
check_scalar(
self.max_depth,
name="max_depth",
target_type=numbers.Integral,
min_val=1,
)
max_depth = np.iinfo(np.int32).max if self.max_depth is None else self.max_depth
max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes

if isinstance(self.min_samples_leaf, numbers.Integral):
if not 1 <= self.min_samples_leaf:
raise ValueError(
"min_samples_leaf must be at least 1 or in (0, 0.5], got %s"
% self.min_samples_leaf
)
check_scalar(
self.min_samples_leaf,
name="min_samples_leaf",
target_type=numbers.Integral,
min_val=1,
)
min_samples_leaf = self.min_samples_leaf
else: # float
if not 0.0 < self.min_samples_leaf <= 0.5:
raise ValueError(
"min_samples_leaf must be at least 1 or in (0, 0.5], got %s"
% self.min_samples_leaf
)
check_scalar(
self.min_samples_leaf,
name="min_samples_leaf",
target_type=numbers.Real,
min_val=0.0,
include_boundaries="right",
)
min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples))

if isinstance(self.min_samples_split, numbers.Integral):
if not 2 <= self.min_samples_split:
raise ValueError(
"min_samples_split must be an integer "
"greater than 1 or a float in (0.0, 1.0]; "
"got the integer %s"
% self.min_samples_split
)
check_scalar(
self.min_samples_split,
name="min_samples_split",
target_type=numbers.Integral,
min_val=2,
)
min_samples_split = self.min_samples_split
else: # float
if not 0.0 < self.min_samples_split <= 1.0:
raise ValueError(
"min_samples_split must be an integer "
"greater than 1 or a float in (0.0, 1.0]; "
"got the float %s"
% self.min_samples_split
)
check_scalar(
self.min_samples_split,
name="min_samples_split",
target_type=numbers.Real,
min_val=0.0,
max_val=1.0,
include_boundaries="right",
)
min_samples_split = int(ceil(self.min_samples_split * n_samples))
min_samples_split = max(2, min_samples_split)

min_samples_split = max(min_samples_split, 2 * min_samples_leaf)

check_scalar(
self.min_weight_fraction_leaf,
name="min_weight_fraction_leaf",
target_type=numbers.Real,
min_val=0.0,
max_val=0.5,
)

if isinstance(self.max_features, str):
if self.max_features == "auto":
if is_classification:
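Between these hunks it is worth spelling out the fraction-to-count conversions the validated values feed into: the `check_scalar` calls only gate the ranges, while the lines above still turn fractional inputs into absolute counts. A worked sketch with a hypothetical sample count:

```python
from math import ceil

n_samples = 100  # hypothetical training-set size

# A fractional min_samples_leaf in (0.0, 0.5] becomes a count.
min_samples_leaf = int(ceil(0.05 * n_samples))  # -> 5

# A fractional min_samples_split in (0.0, 1.0] becomes a count,
# floored at 2 so a split always has at least two samples to divide.
min_samples_split = max(2, int(ceil(0.01 * n_samples)))  # max(2, 1) == 2

# Finally the two are coupled: a split must be able to leave
# min_samples_leaf samples in each of its two children.
min_samples_split = max(min_samples_split, 2 * min_samples_leaf)  # -> 10
```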
@@ -284,36 +306,51 @@ def fit(self, X, y, sample_weight=None, check_input=True):
elif self.max_features is None:
max_features = self.n_features_in_
elif isinstance(self.max_features, numbers.Integral):
check_scalar(
self.max_features,
name="max_features",
target_type=numbers.Integral,
min_val=1,
max_val=self.n_features_in_,
)
max_features = self.max_features
else: # float
check_scalar(
self.max_features,
name="max_features",
target_type=numbers.Real,
min_val=0.0,
max_val=1.0,
include_boundaries="right",
Member:

Here there is a subtle ambiguity for 1, which could be an Integral or a Real depending on whether one passes 1 or 1.0.

Basically, I would expect:

  • None allows taking all features.
  • 1 as an integer will select one feature.
  • 1.0 as a real should not be included.

But I am not really sure what is best here. Maybe @thomasjpfan or @ogrisel remember how we want to handle the validation for this ambiguity.

Member:

My impression was that 1.0 means "take all the features", which is in line with the current implementation.

Member:

Looking at different places, it seems that we do the same as stated by @thomasjpfan. I am still not a super fan that we can pass both 1 and 1.0 and have them behave differently. (A short sketch of this ambiguity follows this hunk.)
)
if self.max_features > 0.0:
max_features = max(1, int(self.max_features * self.n_features_in_))
else:
max_features = 0

self.max_features_ = max_features

if self.max_leaf_nodes is not None:
check_scalar(
self.max_leaf_nodes,
name="max_leaf_nodes",
target_type=numbers.Integral,
min_val=2,
)
max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes

check_scalar(
self.min_impurity_decrease,
name="min_impurity_decrease",
target_type=numbers.Real,
min_val=0.0,
)

if len(y) != n_samples:
raise ValueError(
"Number of labels=%d does not match number of samples=%d"
% (len(y), n_samples)
)
if not 0 <= self.min_weight_fraction_leaf <= 0.5:
raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")
if max_depth <= 0:
raise ValueError("max_depth must be greater than zero. ")
if not (0 < max_features <= self.n_features_in_):
raise ValueError("max_features must be in (0, n_features]")
if not isinstance(max_leaf_nodes, numbers.Integral):
raise ValueError(
"max_leaf_nodes must be integral number but was %r" % max_leaf_nodes
)
if -1 < max_leaf_nodes < 2:
raise ValueError(
("max_leaf_nodes {0} must be either None or larger than 1").format(
max_leaf_nodes
)
)

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X, DOUBLE)
Expand All @@ -330,9 +367,6 @@ def fit(self, X, y, sample_weight=None, check_input=True):
else:
min_weight_leaf = self.min_weight_fraction_leaf * np.sum(sample_weight)

if self.min_impurity_decrease < 0.0:
raise ValueError("min_impurity_decrease must be greater than or equal to 0")

# Build tree
criterion = self.criterion
if not isinstance(criterion, Criterion):
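Picking up the review thread above on `max_features`: the integer and float branches give 1 and 1.0 different meanings. A small sketch of the resolution logic with a hypothetical feature count (the helper name is illustrative, not part of the codebase):

```python
import numbers

N_FEATURES_IN = 10  # hypothetical n_features_in_


def resolve_max_features(max_features):
    # Illustrative helper mirroring the branching above: an Integral is an
    # absolute count; a Real is a fraction of the number of features.
    if isinstance(max_features, numbers.Integral):
        return max_features
    if max_features > 0.0:
        return max(1, int(max_features * N_FEATURES_IN))
    return 0


print(resolve_max_features(1))    # 1  -> select a single feature
print(resolve_max_features(1.0))  # 10 -> take all features
```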
@@ -536,9 +570,6 @@ def _prune_tree(self):
"""Prune tree using Minimal Cost-Complexity Pruning."""
check_is_fitted(self)

if self.ccp_alpha < 0.0:
raise ValueError("ccp_alpha must be greater than or equal to 0")

if self.ccp_alpha == 0.0:
return
