Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit b80138f

Browse files
genvalen, glemaitre, and thomasjpfan
authored
MAINT Use check_scalar in BaseGradientBoosting (#21632)
Co-authored-by: Guillaume Lemaitre <[email protected]> Co-authored-by: Thomas J. Fan <[email protected]>
1 parent e977238 commit b80138f

File tree

5 files changed

+203
-64
lines changed

5 files changed

+203
-64
lines changed

doc/whats_new/v1.1.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,11 @@ Changelog
246246
:pr:`20811`, :pr:`20567` and :pr:`21814` by
247247
:user:`Christian Lorentzen <lorentzenchr>`.
248248

249+
- |Fix| Change the parameter `validation_fraction` in
250+
:class:`ensemble.BaseGradientBoosting` so that an error is raised if anything
251+
other than a float is passed in as an argument.
252+
:pr:`21632` by :user:`Genesis Valencia <genvalen>`
253+
249254
- |API| Changed the default of :func:`max_features` to 1.0 for
250255
:class:`ensemble.RandomForestRegressor` and to `"sqrt"` for
251256
:class:`ensemble.RandomForestClassifier`. Note that these give the same fit

sklearn/ensemble/_gb.py

Lines changed: 90 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
from ..utils import check_random_state
5252
from ..utils import check_array
53+
from ..utils import check_scalar
5354
from ..utils import column_or_1d
5455
from ..utils.validation import check_is_fitted, _check_sample_weight
5556
from ..utils.multiclass import check_classification_targets
@@ -265,21 +266,28 @@ def _fit_stage(
265266

266267
def _check_params(self):
267268
"""Check validity of parameters and raise ValueError if not valid."""
268-
if self.n_estimators <= 0:
269-
raise ValueError(
270-
"n_estimators must be greater than 0 but was %r" % self.n_estimators
271-
)
272269

273-
if self.learning_rate <= 0.0:
274-
raise ValueError(
275-
"learning_rate must be greater than 0 but was %r" % self.learning_rate
276-
)
270+
check_scalar(
271+
self.learning_rate,
272+
name="learning_rate",
273+
target_type=numbers.Real,
274+
min_val=0.0,
275+
include_boundaries="neither",
276+
)
277+
278+
check_scalar(
279+
self.n_estimators,
280+
name="n_estimators",
281+
target_type=numbers.Integral,
282+
min_val=1,
283+
include_boundaries="left",
284+
)
277285

278286
if (
279287
self.loss not in self._SUPPORTED_LOSS
280288
or self.loss not in _gb_losses.LOSS_FUNCTIONS
281289
):
282-
raise ValueError("Loss '{0:s}' not supported. ".format(self.loss))
290+
raise ValueError(f"Loss {self.loss!r} not supported. ")
283291

284292
# TODO: Remove in v1.2
285293
if self.loss == "ls":
@@ -313,8 +321,14 @@ def _check_params(self):
313321
else:
314322
self.loss_ = loss_class()
315323

316-
if not (0.0 < self.subsample <= 1.0):
317-
raise ValueError("subsample must be in (0,1] but was %r" % self.subsample)
324+
check_scalar(
325+
self.subsample,
326+
name="subsample",
327+
target_type=numbers.Real,
328+
min_val=0.0,
329+
max_val=1.0,
330+
include_boundaries="right",
331+
)
318332

319333
if self.init is not None:
320334
# init must be an estimator or 'zero'
@@ -323,11 +337,17 @@ def _check_params(self):
323337
elif not (isinstance(self.init, str) and self.init == "zero"):
324338
raise ValueError(
325339
"The init parameter must be an estimator or 'zero'. "
326-
"Got init={}".format(self.init)
340+
f"Got init={self.init!r}"
327341
)
328342

329-
if not (0.0 < self.alpha < 1.0):
330-
raise ValueError("alpha must be in (0.0, 1.0) but was %r" % self.alpha)
343+
check_scalar(
344+
self.alpha,
345+
name="alpha",
346+
target_type=numbers.Real,
347+
min_val=0.0,
348+
max_val=1.0,
349+
include_boundaries="neither",
350+
)
331351

332352
if isinstance(self.max_features, str):
333353
if self.max_features == "auto":
@@ -341,29 +361,66 @@ def _check_params(self):
341361
max_features = max(1, int(np.log2(self.n_features_in_)))
342362
else:
343363
raise ValueError(
344-
"Invalid value for max_features: %r. "
345-
"Allowed string values are 'auto', 'sqrt' "
346-
"or 'log2'."
347-
% self.max_features
364+
f"Invalid value for max_features: {self.max_features!r}. "
365+
"Allowed string values are 'auto', 'sqrt' or 'log2'."
348366
)
349367
elif self.max_features is None:
350368
max_features = self.n_features_in_
351369
elif isinstance(self.max_features, numbers.Integral):
370+
check_scalar(
371+
self.max_features,
372+
name="max_features",
373+
target_type=numbers.Integral,
374+
min_val=1,
375+
include_boundaries="left",
376+
)
352377
max_features = self.max_features
353378
else: # float
354-
if 0.0 < self.max_features <= 1.0:
355-
max_features = max(int(self.max_features * self.n_features_in_), 1)
356-
else:
357-
raise ValueError("max_features must be in (0, n_features]")
379+
check_scalar(
380+
self.max_features,
381+
name="max_features",
382+
target_type=numbers.Real,
383+
min_val=0.0,
384+
max_val=1.0,
385+
include_boundaries="right",
386+
)
387+
max_features = max(1, int(self.max_features * self.n_features_in_))
358388

359389
self.max_features_ = max_features
360390

361-
if not isinstance(self.n_iter_no_change, (numbers.Integral, type(None))):
362-
raise ValueError(
363-
"n_iter_no_change should either be None or an integer. %r was passed"
364-
% self.n_iter_no_change
391+
check_scalar(
392+
self.verbose,
393+
name="verbose",
394+
target_type=(numbers.Integral, np.bool_),
395+
min_val=0,
396+
)
397+
398+
check_scalar(
399+
self.validation_fraction,
400+
name="validation_fraction",
401+
target_type=numbers.Real,
402+
min_val=0.0,
403+
max_val=1.0,
404+
include_boundaries="neither",
405+
)
406+
407+
if self.n_iter_no_change is not None:
408+
check_scalar(
409+
self.n_iter_no_change,
410+
name="n_iter_no_change",
411+
target_type=numbers.Integral,
412+
min_val=1,
413+
include_boundaries="left",
365414
)
366415

416+
check_scalar(
417+
self.tol,
418+
name="tol",
419+
target_type=numbers.Real,
420+
min_val=0.0,
421+
include_boundaries="neither",
422+
)
423+
367424
def _init_state(self):
368425
"""Initialize model state and allocate model state data structures."""
369426

@@ -477,6 +534,11 @@ def fit(self, X, y, sample_weight=None, monitor=None):
477534
)
478535

479536
# if not warmstart - clear the estimator state
537+
check_scalar(
538+
self.warm_start,
539+
name="warm_start",
540+
target_type=(numbers.Integral, np.bool_),
541+
)
480542
if not self.warm_start:
481543
self._clear_state()
482544

@@ -499,6 +561,8 @@ def fit(self, X, y, sample_weight=None, monitor=None):
499561
else:
500562
y = self._validate_y(y)
501563

564+
self._check_params()
565+
502566
if self.n_iter_no_change is not None:
503567
stratify = y if is_classifier(self) else None
504568
X, X_val, y, y_val, sample_weight, sample_weight_val = train_test_split(
@@ -523,8 +587,6 @@ def fit(self, X, y, sample_weight=None, monitor=None):
523587
else:
524588
X_val = y_val = sample_weight_val = None
525589

526-
self._check_params()
527-
528590
if not self._is_initialized():
529591
# init state
530592
self._init_state()

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 107 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -80,60 +80,131 @@ def test_classification_toy(loss):
8080
@pytest.mark.parametrize(
8181
"params, err_type, err_msg",
8282
[
83-
({"n_estimators": 0}, ValueError, "n_estimators must be greater than 0"),
84-
({"n_estimators": -1}, ValueError, "n_estimators must be greater than 0"),
85-
({"learning_rate": 0}, ValueError, "learning_rate must be greater than 0"),
86-
({"learning_rate": -1.0}, ValueError, "learning_rate must be greater than 0"),
83+
({"learning_rate": 0}, ValueError, "learning_rate == 0, must be > 0.0"),
84+
(
85+
{"learning_rate": "foo"},
86+
TypeError,
87+
"learning_rate must be an instance of <class 'numbers.Real'>",
88+
),
89+
({"n_estimators": 0}, ValueError, "n_estimators == 0, must be >= 1"),
90+
(
91+
{"n_estimators": 1.5},
92+
TypeError,
93+
"n_estimators must be an instance of <class 'numbers.Integral'>,",
94+
),
8795
({"loss": "foobar"}, ValueError, "Loss 'foobar' not supported"),
96+
({"subsample": 0.0}, ValueError, "subsample == 0.0, must be > 0.0"),
97+
({"subsample": 1.1}, ValueError, "subsample == 1.1, must be <= 1.0"),
8898
(
89-
{"min_samples_split": 0.0},
90-
ValueError,
91-
"min_samples_split == 0.0, must be > 0.0",
99+
{"subsample": "foo"},
100+
TypeError,
101+
"subsample must be an instance of <class 'numbers.Real'>",
102+
),
103+
({"init": {}}, ValueError, "The init parameter must be an estimator or 'zero'"),
104+
({"max_features": 0}, ValueError, "max_features == 0, must be >= 1"),
105+
({"max_features": 0.0}, ValueError, "max_features == 0.0, must be > 0.0"),
106+
({"max_features": 1.1}, ValueError, "max_features == 1.1, must be <= 1.0"),
107+
({"max_features": "foobar"}, ValueError, "Invalid value for max_features."),
108+
({"verbose": -1}, ValueError, "verbose == -1, must be >= 0"),
109+
(
110+
{"verbose": "foo"},
111+
TypeError,
112+
"verbose must be an instance of",
92113
),
114+
({"warm_start": "foo"}, TypeError, "warm_start must be an instance of"),
93115
(
94-
{"min_samples_split": -1.0},
116+
{"validation_fraction": 0.0},
95117
ValueError,
96-
"min_samples_split == -1.0, must be > 0.0",
118+
"validation_fraction == 0.0, must be > 0.0",
97119
),
98120
(
99-
{"min_samples_split": 1.1},
121+
{"validation_fraction": 1.0},
100122
ValueError,
101-
"min_samples_split == 1.1, must be <= 1.0.",
123+
"validation_fraction == 1.0, must be < 1.0",
124+
),
125+
(
126+
{"validation_fraction": "foo"},
127+
TypeError,
128+
"validation_fraction must be an instance of <class 'numbers.Real'>",
129+
),
130+
({"n_iter_no_change": 0}, ValueError, "n_iter_no_change == 0, must be >= 1"),
131+
(
132+
{"n_iter_no_change": 1.5},
133+
TypeError,
134+
"n_iter_no_change must be an instance of <class 'numbers.Integral'>,",
135+
),
136+
({"tol": 0.0}, ValueError, "tol == 0.0, must be > 0.0"),
137+
(
138+
{"tol": "foo"},
139+
TypeError,
140+
"tol must be an instance of <class 'numbers.Real'>,",
102141
),
142+
# The following parameters are checked in BaseDecisionTree
103143
({"min_samples_leaf": 0}, ValueError, "min_samples_leaf == 0, must be >= 1"),
144+
({"min_samples_leaf": 0.0}, ValueError, "min_samples_leaf == 0.0, must be > 0"),
145+
(
146+
{"min_samples_leaf": "foo"},
147+
TypeError,
148+
"min_samples_leaf must be an instance of <class 'numbers.Real'>",
149+
),
150+
({"min_samples_split": 1}, ValueError, "min_samples_split == 1, must be >= 2"),
104151
(
105-
{"min_samples_leaf": -1.0},
152+
{"min_samples_split": 0.0},
106153
ValueError,
107-
"min_samples_leaf == -1.0, must be > 0.0.",
154+
"min_samples_split == 0.0, must be > 0.0",
108155
),
109156
(
110-
{"min_weight_fraction_leaf": -1.0},
157+
{"min_samples_split": 1.1},
111158
ValueError,
112-
"min_weight_fraction_leaf == -1.0, must be >= 0",
159+
"min_samples_split == 1.1, must be <= 1.0",
113160
),
114161
(
115-
{"min_weight_fraction_leaf": 0.6},
162+
{"min_samples_split": "foo"},
163+
TypeError,
164+
"min_samples_split must be an instance of <class 'numbers.Real'>",
165+
),
166+
(
167+
{"min_weight_fraction_leaf": -1},
116168
ValueError,
117-
"min_weight_fraction_leaf == 0.6, must be <= 0.5.",
169+
"min_weight_fraction_leaf == -1, must be >= 0.0",
118170
),
119-
({"subsample": 0.0}, ValueError, r"subsample must be in \(0,1\]"),
120-
({"subsample": 1.1}, ValueError, r"subsample must be in \(0,1\]"),
121-
({"subsample": -0.1}, ValueError, r"subsample must be in \(0,1\]"),
122-
({"max_depth": -0.1}, TypeError, "max_depth must be an instance of"),
123-
({"max_depth": 0}, ValueError, "max_depth == 0, must be >= 1."),
124-
({"init": {}}, ValueError, "The init parameter must be an estimator or 'zero'"),
125-
({"max_features": "invalid"}, ValueError, "Invalid value for max_features:"),
126-
({"max_features": 0}, ValueError, "max_features == 0, must be >= 1"),
127-
({"max_features": 100}, ValueError, "max_features == 100, must be <="),
128171
(
129-
{"max_features": -0.1},
172+
{"min_weight_fraction_leaf": 0.6},
130173
ValueError,
131-
r"max_features must be in \(0, n_features\]",
174+
"min_weight_fraction_leaf == 0.6, must be <= 0.5",
132175
),
133176
(
134-
{"n_iter_no_change": "invalid"},
177+
{"min_weight_fraction_leaf": "foo"},
178+
TypeError,
179+
"min_weight_fraction_leaf must be an instance of <class 'numbers.Real'>",
180+
),
181+
({"max_leaf_nodes": 0}, ValueError, "max_leaf_nodes == 0, must be >= 2"),
182+
(
183+
{"max_leaf_nodes": 1.5},
184+
TypeError,
185+
"max_leaf_nodes must be an instance of <class 'numbers.Integral'>",
186+
),
187+
({"max_depth": -1}, ValueError, "max_depth == -1, must be >= 1"),
188+
(
189+
{"max_depth": 1.1},
190+
TypeError,
191+
"max_depth must be an instance of <class 'numbers.Integral'>",
192+
),
193+
(
194+
{"min_impurity_decrease": -1},
135195
ValueError,
136-
"n_iter_no_change should either be",
196+
"min_impurity_decrease == -1, must be >= 0.0",
197+
),
198+
(
199+
{"min_impurity_decrease": "foo"},
200+
TypeError,
201+
"min_impurity_decrease must be an instance of <class 'numbers.Real'>",
202+
),
203+
({"ccp_alpha": -1.0}, ValueError, "ccp_alpha == -1.0, must be >= 0.0"),
204+
(
205+
{"ccp_alpha": "foo"},
206+
TypeError,
207+
"ccp_alpha must be an instance of <class 'numbers.Real'>",
137208
),
138209
({"criterion": "mae"}, ValueError, "criterion='mae' is not supported."),
139210
],
@@ -158,8 +229,10 @@ def test_gbdt_parameter_checks(GradientBoosting, X, y, params, err_type, err_msg
158229
@pytest.mark.parametrize(
159230
"params, err_msg",
160231
[
161-
({"loss": "huber", "alpha": 1.2}, r"alpha must be in \(0.0, 1.0\)"),
162-
({"loss": "quantile", "alpha": 1.2}, r"alpha must be in \(0.0, 1.0\)"),
232+
({"loss": "huber", "alpha": 0.0}, "alpha == 0.0, must be > 0.0"),
233+
({"loss": "quantile", "alpha": 0.0}, "alpha == 0.0, must be > 0.0"),
234+
({"loss": "huber", "alpha": 1.2}, "alpha == 1.2, must be < 1.0"),
235+
({"loss": "quantile", "alpha": 1.2}, "alpha == 1.2, must be < 1.0"),
163236
],
164237
)
165238
def test_gbdt_loss_alpha_error(params, err_msg):
@@ -1389,7 +1462,7 @@ def test_early_stopping_n_classes():
13891462
X = [[1]] * 10
13901463
y = [0, 0] + [1] * 8 # only 2 negative class over 10 samples
13911464
gb = GradientBoostingClassifier(
1392-
n_iter_no_change=5, random_state=0, validation_fraction=8
1465+
n_iter_no_change=5, random_state=0, validation_fraction=0.8
13931466
)
13941467
with pytest.raises(
13951468
ValueError, match="The training data after the early stopping split"
@@ -1398,7 +1471,7 @@ def test_early_stopping_n_classes():
13981471

13991472
# No error if we let training data be big enough
14001473
gb = GradientBoostingClassifier(
1401-
n_iter_no_change=5, random_state=0, validation_fraction=4
1474+
n_iter_no_change=5, random_state=0, validation_fraction=0.4
14021475
)
14031476

14041477

0 commit comments

Comments (0)