Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 4be28d4

Browse files
ashah002glemaitrejeremiedbb
authored
MAINT Parameter validation for sklearn.metrics.d2_pinball_score (#25414)
Co-authored-by: Guillaume Lemaitre <[email protected]> Co-authored-by: Jérémie du Boisberranger <[email protected]>
1 parent ba16dbe commit 4be28d4

File tree

3 files changed

+15
-11
lines changed

3 files changed

+15
-11
lines changed

sklearn/metrics/_regression.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,6 +1304,18 @@ def d2_tweedie_score(y_true, y_pred, *, sample_weight=None, power=0):
13041304
return 1 - numerator / denominator
13051305

13061306

1307+
@validate_params(
1308+
{
1309+
"y_true": ["array-like"],
1310+
"y_pred": ["array-like"],
1311+
"sample_weight": ["array-like", None],
1312+
"alpha": [Interval(Real, 0, 1, closed="both")],
1313+
"multioutput": [
1314+
StrOptions({"raw_values", "uniform_average"}),
1315+
"array-like",
1316+
],
1317+
}
1318+
)
13071319
def d2_pinball_score(
13081320
y_true, y_pred, *, sample_weight=None, alpha=0.5, multioutput="uniform_average"
13091321
):
@@ -1327,7 +1339,7 @@ def d2_pinball_score(
13271339
y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
13281340
Estimated target values.
13291341
1330-
sample_weight : array-like of shape (n_samples,), optional
1342+
sample_weight : array-like of shape (n_samples,), default=None
13311343
Sample weights.
13321344
13331345
alpha : float, default=0.5
@@ -1434,15 +1446,9 @@ def d2_pinball_score(
14341446
if multioutput == "raw_values":
14351447
# return scores individually
14361448
return output_scores
1437-
elif multioutput == "uniform_average":
1449+
else: # multioutput == "uniform_average"
14381450
# passing None as weights to np.average results in uniform mean
14391451
avg_weights = None
1440-
else:
1441-
raise ValueError(
1442-
"multioutput is expected to be 'raw_values' "
1443-
"or 'uniform_average' but we got %r"
1444-
" instead." % multioutput
1445-
)
14461452
else:
14471453
avg_weights = multioutput
14481454

sklearn/metrics/tests/test_regression.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,6 @@ def test_regression_multioutput_array():
351351
with pytest.raises(ValueError, match=err_msg):
352352
mean_pinball_loss(y_true, y_pred, multioutput="variance_weighted")
353353

354-
with pytest.raises(ValueError, match=err_msg):
355-
d2_pinball_score(y_true, y_pred, multioutput="variance_weighted")
356-
357354
pbl = mean_pinball_loss(y_true, y_pred, multioutput="raw_values")
358355
mape = mean_absolute_percentage_error(y_true, y_pred, multioutput="raw_values")
359356
r = r2_score(y_true, y_pred, multioutput="raw_values")

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def _check_function_param_validation(
117117
"sklearn.metrics.cluster.contingency_matrix",
118118
"sklearn.metrics.cohen_kappa_score",
119119
"sklearn.metrics.confusion_matrix",
120+
"sklearn.metrics.d2_pinball_score",
120121
"sklearn.metrics.det_curve",
121122
"sklearn.metrics.hamming_loss",
122123
"sklearn.metrics.mean_absolute_error",

0 commit comments

Comments
 (0)