From bd7dfcc3031aa9b6f17cd6e568caf53cbb39bf32 Mon Sep 17 00:00:00 2001 From: Ansam Zedan Date: Mon, 3 Apr 2023 20:02:32 +0200 Subject: [PATCH 1/2] validation --- sklearn/metrics/_regression.py | 12 +++++++++++- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index b19fc2e7f3f70..a1f5ce0843f83 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -1499,7 +1499,17 @@ def d2_pinball_score( return np.average(output_scores, weights=avg_weights) - +@validate_params( + { + "y_true": ["array-like"], + "y_pred": ["array-like"], + "sample_weight": ["array-like", None], + "multioutput": [ + StrOptions({"raw_values", "uniform_average"}), + "array-like", + ], + } +) def d2_absolute_error_score( y_true, y_pred, *, sample_weight=None, multioutput="uniform_average" ): diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index b2d6b0da4f379..fe15f7ed4c31e 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -176,6 +176,7 @@ def _check_function_param_validation( "sklearn.metrics.cohen_kappa_score", "sklearn.metrics.confusion_matrix", "sklearn.metrics.coverage_error", + "sklearn.metrics.d2_absolute_error_score", "sklearn.metrics.d2_pinball_score", "sklearn.metrics.d2_tweedie_score", "sklearn.metrics.dcg_score", From 774e786a6b1cc9b72c0789475ac0854216d8656a Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Tue, 4 Apr 2023 13:50:34 +0200 Subject: [PATCH 2/2] address review comments --- sklearn/metrics/_regression.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index a1f5ce0843f83..43e3198e9349d 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -1499,6 +1499,7 @@ def d2_pinball_score( return np.average(output_scores, weights=avg_weights) + @validate_params( { "y_true": ["array-like"], @@ -1514,8 +1515,7 @@ def d2_absolute_error_score( y_true, y_pred, *, sample_weight=None, multioutput="uniform_average" ): """ - :math:`D^2` regression score function, \ - fraction of absolute error explained. + :math:`D^2` regression score function, fraction of absolute error explained. Best possible score is 1.0 and it can be negative (because the model can be arbitrarily worse). A model that always uses the empirical median of `y_true` @@ -1534,7 +1534,7 @@ def d2_absolute_error_score( y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs) Estimated target values. - sample_weight : array-like of shape (n_samples,), optional + sample_weight : array-like of shape (n_samples,), default=None Sample weights. multioutput : {'raw_values', 'uniform_average'} or array-like of shape \