diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index 297e83173e47e..e3d46f5138fb2 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -814,6 +814,14 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
     return fps, tps, y_score[threshold_idxs]
 
 
+@validate_params(
+    {
+        "y_true": ["array-like"],
+        "probas_pred": ["array-like"],
+        "pos_label": [Real, str, "boolean", None],
+        "sample_weight": ["array-like", None],
+    }
+)
 def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight=None):
     """Compute precision-recall pairs for different probability thresholds.
 
@@ -839,11 +847,11 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight
 
     Parameters
     ----------
-    y_true : ndarray of shape (n_samples,)
+    y_true : array-like of shape (n_samples,)
         True binary labels. If labels are not either {-1, 1} or {0, 1}, then
         pos_label should be explicitly given.
 
-    probas_pred : ndarray of shape (n_samples,)
+    probas_pred : array-like of shape (n_samples,)
         Target scores, can either be probability estimates of the positive
         class, or non-thresholded measure of decisions (as returned by
         `decision_function` on some classifiers).
diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py
index 4e13bb46ef645..a9cb675f43423 100644
--- a/sklearn/tests/test_public_functions.py
+++ b/sklearn/tests/test_public_functions.py
@@ -134,6 +134,7 @@ def _check_function_param_validation(
     "sklearn.metrics.multilabel_confusion_matrix",
     "sklearn.metrics.mutual_info_score",
     "sklearn.metrics.pairwise.additive_chi2_kernel",
+    "sklearn.metrics.precision_recall_curve",
     "sklearn.metrics.precision_recall_fscore_support",
     "sklearn.metrics.r2_score",
     "sklearn.metrics.roc_curve",
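
For context, a minimal usage sketch of what `@validate_params` buys here: arguments are checked against the declared constraints before the function body runs, so a wrongly typed `pos_label` now fails fast with an informative `InvalidParameterError` instead of an obscure failure deeper in the call stack. This assumes a scikit-learn build that includes this patch; the exact wording of the error message may vary between versions.

    import numpy as np
    from sklearn.metrics import precision_recall_curve

    y_true = np.array([0, 0, 1, 1])
    y_score = np.array([0.1, 0.4, 0.35, 0.8])

    # Valid call: array-like inputs, pos_label is a Real number.
    precision, recall, thresholds = precision_recall_curve(
        y_true, y_score, pos_label=1
    )

    # Invalid call: a list is none of Real, str, bool, or None, so the
    # validation layer rejects it before any computation runs.
    try:
        precision_recall_curve(y_true, y_score, pos_label=[1])
    except ValueError as exc:  # InvalidParameterError subclasses ValueError
        print(f"{type(exc).__name__}: {exc}")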