Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ebab121

Browse files
TheBicPenglemaitre
andauthored
MAINT Use check_scalar to validate scalar in: BayesianRidge (#23051)
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent f5871a3 commit ebab121

File tree

2 files changed

+185
-15
lines changed

2 files changed

+185
-15
lines changed

sklearn/linear_model/_bayes.py

Lines changed: 100 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
# License: BSD 3 clause
77

88
from math import log
9+
import numbers
910
import numpy as np
1011
from scipy import linalg
1112

1213
from ._base import LinearModel, _preprocess_data, _rescale_data
1314
from ..base import RegressorMixin
1415
from ._base import _deprecate_normalize
1516
from ..utils.extmath import fast_logdet
17+
from ..utils import check_scalar
1618
from scipy.linalg import pinvh
1719
from ..utils.validation import _check_sample_weight
1820

@@ -205,6 +207,103 @@ def __init__(
205207
self.copy_X = copy_X
206208
self.verbose = verbose
207209

210+
def _check_params(self):
211+
"""Check validity of parameters and raise ValueError
212+
or TypeError if not valid."""
213+
214+
check_scalar(
215+
self.n_iter,
216+
name="n_iter",
217+
target_type=numbers.Integral,
218+
min_val=1,
219+
)
220+
221+
check_scalar(
222+
self.tol,
223+
name="tol",
224+
target_type=numbers.Real,
225+
min_val=0.0,
226+
include_boundaries="neither",
227+
)
228+
229+
check_scalar(
230+
self.alpha_1,
231+
name="alpha_1",
232+
target_type=numbers.Real,
233+
min_val=0.0,
234+
include_boundaries="left",
235+
)
236+
237+
check_scalar(
238+
self.alpha_2,
239+
name="alpha_2",
240+
target_type=numbers.Real,
241+
min_val=0.0,
242+
include_boundaries="left",
243+
)
244+
245+
check_scalar(
246+
self.lambda_1,
247+
name="lambda_1",
248+
target_type=numbers.Real,
249+
min_val=0.0,
250+
include_boundaries="left",
251+
)
252+
253+
check_scalar(
254+
self.lambda_2,
255+
name="lambda_2",
256+
target_type=numbers.Real,
257+
min_val=0.0,
258+
include_boundaries="left",
259+
)
260+
261+
if self.alpha_init is not None:
262+
check_scalar(
263+
self.alpha_init,
264+
name="alpha_init",
265+
target_type=numbers.Real,
266+
include_boundaries="neither",
267+
)
268+
269+
if self.lambda_init is not None:
270+
check_scalar(
271+
self.lambda_init,
272+
name="lambda_init",
273+
target_type=numbers.Real,
274+
include_boundaries="neither",
275+
)
276+
277+
check_scalar(
278+
self.compute_score,
279+
name="compute_score",
280+
target_type=(np.bool_, bool),
281+
)
282+
283+
check_scalar(
284+
self.fit_intercept,
285+
name="fit_intercept",
286+
target_type=(np.bool_, bool),
287+
)
288+
289+
self._normalize = _deprecate_normalize(
290+
self.normalize, default=False, estimator_name=self.__class__.__name__
291+
)
292+
293+
check_scalar(
294+
self.copy_X,
295+
name="copy_X",
296+
target_type=(np.bool_, bool),
297+
)
298+
299+
check_scalar(
300+
self.verbose,
301+
name="verbose",
302+
target_type=(numbers.Integral, np.bool_, bool),
303+
min_val=0,
304+
max_val=1,
305+
)
306+
208307
def fit(self, X, y, sample_weight=None):
209308
"""Fit the model.
210309
@@ -226,16 +325,7 @@ def fit(self, X, y, sample_weight=None):
226325
self : object
227326
Returns the instance itself.
228327
"""
229-
self._normalize = _deprecate_normalize(
230-
self.normalize, default=False, estimator_name=self.__class__.__name__
231-
)
232-
233-
if self.n_iter < 1:
234-
raise ValueError(
235-
"n_iter should be greater than or equal to 1. Got {!r}.".format(
236-
self.n_iter
237-
)
238-
)
328+
self._check_params()
239329

240330
X, y = self._validate_data(X, y, dtype=[np.float64, np.float32], y_numeric=True)
241331

sklearn/linear_model/tests/test_bayes.py

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,93 @@
2121
diabetes = datasets.load_diabetes()
2222

2323

24-
def test_n_iter():
25-
"""Check value of n_iter."""
24+
@pytest.mark.parametrize(
25+
"params, err_type, err_msg",
26+
[
27+
({"n_iter": 0}, ValueError, "n_iter == 0, must be >= 1."),
28+
({"n_iter": 2.5}, TypeError, "n_iter must be an instance of int, not float."),
29+
({"tol": -1}, ValueError, "tol == -1, must be > 0"),
30+
({"tol": "-1"}, TypeError, "tol must be an instance of float, not str."),
31+
(
32+
{"alpha_1": "-1"},
33+
TypeError,
34+
"alpha_1 must be an instance of float, not str.",
35+
),
36+
(
37+
{"alpha_2": "-1"},
38+
TypeError,
39+
"alpha_2 must be an instance of float, not str.",
40+
),
41+
(
42+
{"lambda_1": "-1"},
43+
TypeError,
44+
"lambda_1 must be an instance of float, not str.",
45+
),
46+
(
47+
{"lambda_2": "-1"},
48+
TypeError,
49+
"lambda_2 must be an instance of float, not str.",
50+
),
51+
(
52+
{"alpha_init": "-1"},
53+
TypeError,
54+
"alpha_init must be an instance of float, not str.",
55+
),
56+
(
57+
{"lambda_init": "-1"},
58+
TypeError,
59+
"lambda_init must be an instance of float, not str.",
60+
),
61+
(
62+
{"compute_score": 2},
63+
TypeError,
64+
"compute_score must be an instance of {numpy.bool_, bool}, not int.",
65+
),
66+
(
67+
{"compute_score": 0.5},
68+
TypeError,
69+
"compute_score must be an instance of {numpy.bool_, bool}, not float.",
70+
),
71+
(
72+
{"fit_intercept": 2},
73+
TypeError,
74+
"fit_intercept must be an instance of {numpy.bool_, bool}, not int.",
75+
),
76+
(
77+
{"fit_intercept": 0.5},
78+
TypeError,
79+
"fit_intercept must be an instance of {numpy.bool_, bool}, not float.",
80+
),
81+
(
82+
{"normalize": -1},
83+
ValueError,
84+
"Leave 'normalize' to its default value or set it to True or False",
85+
),
86+
(
87+
{"copy_X": 2},
88+
TypeError,
89+
"copy_X must be an instance of {numpy.bool_, bool}, not int.",
90+
),
91+
(
92+
{"copy_X": 0.5},
93+
TypeError,
94+
"copy_X must be an instance of {numpy.bool_, bool}, not float.",
95+
),
96+
({"verbose": -1}, ValueError, "verbose == -1, must be >= 0"),
97+
({"verbose": 2}, ValueError, "verbose == 2, must be <= 1"),
98+
(
99+
{"verbose": 0.5},
100+
TypeError,
101+
"verbose must be an instance of {int, numpy.bool_, bool}, not float.",
102+
),
103+
],
104+
)
105+
def test_bayesian_ridge_scalar_params_validation(params, err_type, err_msg):
106+
"""Check the scalar parameters of BayesianRidge."""
26107
X = np.array([[1], [2], [6], [8], [10]])
27108
y = np.array([1, 2, 6, 8, 10])
28-
clf = BayesianRidge(n_iter=0)
29-
msg = "n_iter should be greater than or equal to 1."
30-
with pytest.raises(ValueError, match=msg):
109+
clf = BayesianRidge(**params)
110+
with pytest.raises(err_type, match=err_msg):
31111
clf.fit(X, y)
32112

33113

0 commit comments

Comments
 (0)