diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index 0602ec77aa500..8d6979020dbf2 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -1451,6 +1451,25 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
     return labels
 
 
+@validate_params(
+    {
+        "y_true": ["array-like", "sparse matrix"],
+        "y_pred": ["array-like", "sparse matrix"],
+        "beta": [Interval(Real, 0.0, None, closed="left")],
+        "labels": ["array-like", None],
+        "pos_label": [Real, str, "boolean", None],
+        "average": [
+            StrOptions({"micro", "macro", "samples", "weighted", "binary"}),
+            None,
+        ],
+        "warn_for": [list, tuple, set],
+        "sample_weight": ["array-like", None],
+        "zero_division": [
+            Options(Real, {0, 1}),
+            StrOptions({"warn"}),
+        ],
+    }
+)
 def precision_recall_fscore_support(
     y_true,
     y_pred,
@@ -1539,7 +1558,7 @@ def precision_recall_fscore_support(
             meaningful for multilabel classification where this differs from
             :func:`accuracy_score`).
 
-    warn_for : tuple or set, for internal use
+    warn_for : list, tuple or set, for internal use
         This determines which warnings will be made in the case that this
         function is being used to return only one of its metrics.
 
@@ -1616,8 +1635,6 @@ def precision_recall_fscore_support(
      array([2, 2, 2]))
     """
     _check_zero_division(zero_division)
-    if beta < 0:
-        raise ValueError("beta should be >=0 in the F-beta score")
     labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label)
 
     # Calculate tp_sum, pred_sum, true_sum ###
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index 222f445555200..8281ec9877d21 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -386,23 +386,6 @@ def test_average_precision_score_tied_values():
     assert average_precision_score(y_true, y_score) != 1.0
 
 
-@ignore_warnings
-def test_precision_recall_fscore_support_errors():
-    y_true, y_pred, _ = make_prediction(binary=True)
-
-    # Bad beta
-    with pytest.raises(ValueError):
-        precision_recall_fscore_support(y_true, y_pred, beta=-0.1)
-
-    # Bad pos_label
-    with pytest.raises(ValueError):
-        precision_recall_fscore_support(y_true, y_pred, pos_label=2, average="binary")
-
-    # Bad average option
-    with pytest.raises(ValueError):
-        precision_recall_fscore_support([0, 1, 2], [1, 2, 0], average="mega")
-
-
 def test_precision_recall_f_unused_pos_label():
     # Check warning that pos_label unused when set to non-default value
     # but average != 'binary'; even if data is binary.
diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py
index 9b2b56cdb3eb8..afe44e3bf2723 100644
--- a/sklearn/tests/test_public_functions.py
+++ b/sklearn/tests/test_public_functions.py
@@ -130,6 +130,7 @@ def _check_function_param_validation(
     "sklearn.metrics.multilabel_confusion_matrix",
     "sklearn.metrics.mutual_info_score",
     "sklearn.metrics.pairwise.additive_chi2_kernel",
+    "sklearn.metrics.precision_recall_fscore_support",
     "sklearn.metrics.r2_score",
     "sklearn.metrics.roc_curve",
     "sklearn.metrics.zero_one_loss",
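
A minimal sketch of the user-visible effect of this change, assuming a
scikit-learn build with the patch applied: an out-of-range beta (or any other
invalid argument) is now rejected by the @validate_params machinery instead of
the removed manual check. The decorator raises InvalidParameterError, a
ValueError subclass, so callers that previously caught ValueError keep
working. The y_true/y_pred values below are illustrative only.

    from sklearn.metrics import precision_recall_fscore_support

    y_true = [0, 1, 1, 0, 1, 1]
    y_pred = [0, 1, 0, 0, 1, 1]

    # A negative beta is now caught by the validate_params decorator before
    # any metric computation runs.
    try:
        precision_recall_fscore_support(y_true, y_pred, beta=-0.1)
    except ValueError as exc:  # InvalidParameterError subclasses ValueError
        print(f"rejected by parameter validation: {exc}")

    # Valid calls behave exactly as before.
    print(precision_recall_fscore_support(y_true, y_pred, average="binary"))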