Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 998e8f2

Browse files
authored
TST invalid init parameters for losses (#22407)
1 parent 4d9e005 commit 998e8f2

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

sklearn/_loss/loss.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# - SGDRegressor, SGDClassifier
1616
# - Replace link module of GLMs.
1717

18+
import numbers
1819
import numpy as np
1920
from scipy.special import xlogy
2021
from ._loss import (
@@ -34,6 +35,7 @@
3435
LogitLink,
3536
MultinomialLogit,
3637
)
38+
from ..utils import check_scalar
3739
from ..utils._readonly_array_wrapper import ReadonlyArrayWrapper
3840
from ..utils.stats import _weighted_percentile
3941

@@ -604,11 +606,14 @@ class PinballLoss(BaseLoss):
604606
need_update_leaves_values = True
605607

606608
def __init__(self, sample_weight=None, quantile=0.5):
607-
if quantile <= 0 or quantile >= 1:
608-
raise ValueError(
609-
"PinballLoss aka quantile loss only accepts "
610-
f"0 < quantile < 1; {quantile} was given."
611-
)
609+
check_scalar(
610+
quantile,
611+
"quantile",
612+
target_type=numbers.Real,
613+
min_val=0,
614+
max_val=1,
615+
include_boundaries="neither",
616+
)
612617
super().__init__(
613618
closs=CyPinballLoss(quantile=float(quantile)),
614619
link=IdentityLink(),
@@ -725,6 +730,14 @@ class HalfTweedieLoss(BaseLoss):
725730
"""
726731

727732
def __init__(self, sample_weight=None, power=1.5):
733+
check_scalar(
734+
power,
735+
"power",
736+
target_type=numbers.Real,
737+
include_boundaries="neither",
738+
min_val=-np.inf,
739+
max_val=np.inf,
740+
)
728741
super().__init__(
729742
closs=CyHalfTweedieLoss(power=float(power)),
730743
link=LogLink(),

sklearn/_loss/tests/test_loss.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,42 @@ def test_init_gradient_and_hessian_raises(loss, params, err_msg):
10481048
gradient, hessian = loss.init_gradient_and_hessian(n_samples=5, **params)
10491049

10501050

1051+
@pytest.mark.parametrize(
1052+
"loss, params, err_type, err_msg",
1053+
[
1054+
(
1055+
PinballLoss,
1056+
{"quantile": None},
1057+
TypeError,
1058+
"quantile must be an instance of float, not NoneType.",
1059+
),
1060+
(
1061+
PinballLoss,
1062+
{"quantile": 0},
1063+
ValueError,
1064+
"quantile == 0, must be > 0.",
1065+
),
1066+
(PinballLoss, {"quantile": 1.1}, ValueError, "quantile == 1.1, must be < 1."),
1067+
(
1068+
HalfTweedieLoss,
1069+
{"power": None},
1070+
TypeError,
1071+
"power must be an instance of float, not NoneType.",
1072+
),
1073+
(
1074+
HalfTweedieLoss,
1075+
{"power": np.inf},
1076+
ValueError,
1077+
"power == inf, must be < inf.",
1078+
),
1079+
],
1080+
)
1081+
def test_loss_init_parameter_validation(loss, params, err_type, err_msg):
1082+
"""Test that loss raises errors for invalid input."""
1083+
with pytest.raises(err_type, match=err_msg):
1084+
loss(**params)
1085+
1086+
10511087
@pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
10521088
def test_loss_pickle(loss):
10531089
"""Test that losses can be pickled."""

0 commit comments

Comments
 (0)