Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

ENH add support for Array API to mean_pinball_loss and explained_variance_score #29978

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,13 @@ Metrics
- :func:`sklearn.metrics.cluster.entropy`
- :func:`sklearn.metrics.accuracy_score`
- :func:`sklearn.metrics.d2_tweedie_score`
- :func:`sklearn.metrics.explained_variance_score`
- :func:`sklearn.metrics.f1_score`
- :func:`sklearn.metrics.max_error`
- :func:`sklearn.metrics.mean_absolute_error`
- :func:`sklearn.metrics.mean_absolute_percentage_error`
- :func:`sklearn.metrics.mean_gamma_deviance`
- :func:`sklearn.metrics.mean_pinball_loss`
- :func:`sklearn.metrics.mean_poisson_deviance` (requires `enabling array API support for SciPy <https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html#using-array-api-standard-support>`_)
- :func:`sklearn.metrics.mean_squared_error`
- :func:`sklearn.metrics.mean_squared_log_error`
Expand Down
3 changes: 3 additions & 0 deletions doc/whats_new/upcoming_changes/array-api/29978.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- :func:`sklearn.metrics.explained_variance_score` and
  :func:`sklearn.metrics.mean_pinball_loss` now support Array API compatible inputs.
  By :user:`Virgil Chan <virchan>`
62 changes: 38 additions & 24 deletions sklearn/metrics/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def mean_absolute_error(
if multioutput == "raw_values":
return output_errors
elif multioutput == "uniform_average":
# pass None as weights to np.average: uniform mean
# pass None as weights to _average: uniform mean
multioutput = None

# Average across the outputs (if needed).
Expand Down Expand Up @@ -360,35 +360,45 @@ def mean_pinball_loss(
>>> from sklearn.metrics import mean_pinball_loss
>>> y_true = [1, 2, 3]
>>> mean_pinball_loss(y_true, [0, 2, 3], alpha=0.1)
np.float64(0.03...)
0.03...
>>> mean_pinball_loss(y_true, [1, 2, 4], alpha=0.1)
np.float64(0.3...)
0.3...
>>> mean_pinball_loss(y_true, [0, 2, 3], alpha=0.9)
np.float64(0.3...)
0.3...
>>> mean_pinball_loss(y_true, [1, 2, 4], alpha=0.9)
np.float64(0.03...)
0.03...
>>> mean_pinball_loss(y_true, y_true, alpha=0.1)
np.float64(0.0)
0.0
>>> mean_pinball_loss(y_true, y_true, alpha=0.9)
np.float64(0.0)
0.0
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput
xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)

_, y_true, y_pred, sample_weight, multioutput = (
_check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=xp
)
)

check_consistent_length(y_true, y_pred, sample_weight)
diff = y_true - y_pred
sign = (diff >= 0).astype(diff.dtype)
sign = xp.astype(diff >= 0, diff.dtype)
loss = alpha * sign * diff - (1 - alpha) * (1 - sign) * diff
output_errors = np.average(loss, weights=sample_weight, axis=0)
output_errors = _average(loss, weights=sample_weight, axis=0)

if isinstance(multioutput, str) and multioutput == "raw_values":
return output_errors

if isinstance(multioutput, str) and multioutput == "uniform_average":
# pass None as weights to np.average: uniform mean
# pass None as weights to _average: uniform mean
multioutput = None

return np.average(output_errors, weights=multioutput)
# Average across the outputs (if needed).
# The second call to `_average` should always return
# a scalar array that we convert to a Python float to
# consistently return the same eager evaluated value.
# Therefore, `axis=None`.
return float(_average(output_errors, weights=multioutput))


@validate_params(
Expand Down Expand Up @@ -949,12 +959,12 @@ def _assemble_r2_explained_variance(
# return scores individually
return output_scores
elif multioutput == "uniform_average":
# Passing None as weights to np.average results in a uniform mean
# pass None as weights to _average: uniform mean
avg_weights = None
elif multioutput == "variance_weighted":
avg_weights = denominator
if not xp.any(nonzero_denominator):
# All weights are zero, np.average would raise a ZeroDiv error.
# All weights are zero, _average would raise a ZeroDiv error.
# This only happens when all y are constant (or 1-element long)
# Since weights are all equal, fall back to uniform weights.
avg_weights = None
Expand Down Expand Up @@ -1083,28 +1093,32 @@ def explained_variance_score(
>>> explained_variance_score(y_true, y_pred, force_finite=False)
-inf
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput
xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight, multioutput)

_, y_true, y_pred, sample_weight, multioutput = (
_check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=xp
)
)

check_consistent_length(y_true, y_pred, sample_weight)

y_diff_avg = np.average(y_true - y_pred, weights=sample_weight, axis=0)
numerator = np.average(
y_diff_avg = _average(y_true - y_pred, weights=sample_weight, axis=0)
numerator = _average(
(y_true - y_pred - y_diff_avg) ** 2, weights=sample_weight, axis=0
)

y_true_avg = np.average(y_true, weights=sample_weight, axis=0)
denominator = np.average((y_true - y_true_avg) ** 2, weights=sample_weight, axis=0)
y_true_avg = _average(y_true, weights=sample_weight, axis=0)
denominator = _average((y_true - y_true_avg) ** 2, weights=sample_weight, axis=0)

return _assemble_r2_explained_variance(
numerator=numerator,
denominator=denominator,
n_outputs=y_true.shape[1],
multioutput=multioutput,
force_finite=force_finite,
xp=get_namespace(y_true)[0],
# TODO: update once Array API support is added to explained_variance_score.
device=None,
xp=xp,
device=device,
)


Expand Down
8 changes: 8 additions & 0 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2084,10 +2084,18 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
check_array_api_regression_metric_multioutput,
],
cosine_similarity: [check_array_api_metric_pairwise],
explained_variance_score: [
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
mean_absolute_error: [
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
mean_pinball_loss: [
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
mean_squared_error: [
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
Expand Down
2 changes: 1 addition & 1 deletion sklearn/metrics/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def test_mean_pinball_loss_on_constant_predictions(distribution, target_quantile
# Check that the loss of this constant predictor is greater or equal
# than the loss of using the optimal quantile (up to machine
# precision):
assert pbl >= best_pbl - np.finfo(best_pbl.dtype).eps
assert pbl >= best_pbl - np.finfo(np.float64).eps

# Check that the value of the pinball loss matches the analytical
# formula.
Expand Down
Loading