Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7970d78

Browse files
nwzishjeremiedbb
andauthored
MAINT Parameters validation for sklearn.metrics.f1_score (#25557)
Co-authored-by: Jérémie du Boisberranger <[email protected]>
1 parent cd25abe commit 7970d78

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

sklearn/metrics/_classification.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
# License: BSD 3 clause
2424

2525

26+
from numbers import Integral, Real
2627
import warnings
2728
import numpy as np
2829

@@ -40,7 +41,7 @@
4041
from ..utils.multiclass import type_of_target
4142
from ..utils.validation import _num_samples
4243
from ..utils.sparsefuncs import count_nonzero
43-
from ..utils._param_validation import StrOptions, validate_params
44+
from ..utils._param_validation import StrOptions, Options, validate_params
4445
from ..exceptions import UndefinedMetricWarning
4546

4647
from ._base import _check_pos_label_consistency
@@ -1038,6 +1039,23 @@ def zero_one_loss(y_true, y_pred, *, normalize=True, sample_weight=None):
10381039
return n_samples - score
10391040

10401041

1042+
@validate_params(
1043+
{
1044+
"y_true": ["array-like", "sparse matrix"],
1045+
"y_pred": ["array-like", "sparse matrix"],
1046+
"labels": ["array-like", None],
1047+
"pos_label": [Real, str, "boolean", None],
1048+
"average": [
1049+
StrOptions({"micro", "macro", "samples", "weighted", "binary"}),
1050+
None,
1051+
],
1052+
"sample_weight": ["array-like", None],
1053+
"zero_division": [
1054+
Options(Integral, {0, 1}),
1055+
StrOptions({"warn"}),
1056+
],
1057+
}
1058+
)
10411059
def f1_score(
10421060
y_true,
10431061
y_pred,
@@ -1083,7 +1101,7 @@ def f1_score(
10831101
.. versionchanged:: 0.17
10841102
Parameter `labels` improved for multiclass problem.
10851103
1086-
pos_label : int, float, bool or str, default=1
1104+
pos_label : int, float, bool, str or None, default=1
10871105
The class to report if ``average='binary'`` and the data is binary.
10881106
If the data are multiclass or multilabel, this will be ignored;
10891107
setting ``labels=[pos_label]`` and ``average != 'binary'`` will report

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def _check_function_param_validation(
120120
"sklearn.metrics.confusion_matrix",
121121
"sklearn.metrics.d2_pinball_score",
122122
"sklearn.metrics.det_curve",
123+
"sklearn.metrics.f1_score",
123124
"sklearn.metrics.hamming_loss",
124125
"sklearn.metrics.mean_absolute_error",
125126
"sklearn.metrics.mean_squared_error",

0 commit comments

Comments
 (0)