MNT Refactor _average_weighted_percentile to avoid double sort #31775

Open · wants to merge 5 commits into base: main
6 changes: 3 additions & 3 deletions sklearn/metrics/_regression.py
@@ -28,7 +28,7 @@
_xlogy as xlogy,
)
from ..utils._param_validation import Interval, StrOptions, validate_params
from ..utils.stats import _averaged_weighted_percentile, _weighted_percentile
from ..utils.stats import _weighted_percentile
from ..utils.validation import (
_check_sample_weight,
_num_samples,
@@ -923,8 +923,8 @@ def median_absolute_error(
if sample_weight is None:
output_errors = _median(xp.abs(y_pred - y_true), axis=0)
else:
output_errors = _averaged_weighted_percentile(
xp.abs(y_pred - y_true), sample_weight=sample_weight
output_errors = _weighted_percentile(
xp.abs(y_pred - y_true), sample_weight=sample_weight, average=True
)
if isinstance(multioutput, str):
if multioutput == "raw_values":
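For illustration, a minimal sketch of the new call site (hedged: `_weighted_percentile` is a private helper, and the `average` parameter exists only on this PR's branch). With unit weights, the `average=True` result should agree with NumPy's unweighted 'averaged_inverted_cdf' method:

import numpy as np

from sklearn.utils.stats import _weighted_percentile  # private helper, this PR's branch

y_true = np.array([2.5, 0.0, 2.0, 8.0])
y_pred = np.array([3.0, -0.5, 2.0, 7.0])
errors = np.abs(y_pred - y_true)  # [0.5, 0.5, 0.0, 1.0]
weights = np.ones_like(errors)

# A single call replaces the old `_averaged_weighted_percentile(errors, weights)`:
median_error = _weighted_percentile(
    errors, sample_weight=weights, percentile_rank=50, average=True
)
print(median_error)  # 0.5
print(np.percentile(errors, 50, method="averaged_inverted_cdf"))  # 0.5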
19 changes: 11 additions & 8 deletions sklearn/preprocessing/_discretization.py
@@ -10,7 +10,7 @@
from ..base import BaseEstimator, TransformerMixin, _fit_context
from ..utils import resample
from ..utils._param_validation import Interval, Options, StrOptions
from ..utils.stats import _averaged_weighted_percentile, _weighted_percentile
from ..utils.stats import _weighted_percentile
from ..utils.validation import (
_check_feature_names_in,
_check_sample_weight,
@@ -357,17 +357,20 @@ def fit(self, X, y=None, sample_weight=None):
dtype=np.float64,
)
else:
# TODO: make _weighted_percentile and
# _averaged_weighted_percentile accept an array of
# TODO: make _weighted_percentile accept an array of
# quantiles instead of calling it multiple times and
# sorting the column multiple times as a result.
percentile_func = {
"inverted_cdf": _weighted_percentile,
"averaged_inverted_cdf": _averaged_weighted_percentile,
}[quantile_method]
average = (
True if quantile_method == "averaged_inverted_cdf" else False
)
bin_edges[jj] = np.asarray(
[
percentile_func(column, sample_weight, percentile_rank=p)
_weighted_percentile(
column,
sample_weight,
percentile_rank=p,
average=average,
)
for p in percentile_levels
],
dtype=np.float64,
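As the TODO above notes, the helper is still called once per quantile level, re-sorting the column each time. A hedged sketch of that pattern (plain NumPy; `column`, `sample_weight`, and `percentile_levels` stand in for the variables in `fit`, and the `average` parameter assumes this PR's branch):

import numpy as np

from sklearn.utils.stats import _weighted_percentile  # private helper, this PR's branch

rng = np.random.default_rng(0)
column = rng.normal(size=100)
sample_weight = np.ones_like(column)
percentile_levels = np.linspace(0, 100, 6)  # 5 bins -> 6 edges

# `average=True` selects 'averaged_inverted_cdf'; each iteration sorts `column`
# again, which is what the TODO about accepting an array of quantiles would avoid.
bin_edges = np.asarray(
    [
        _weighted_percentile(column, sample_weight, percentile_rank=p, average=True)
        for p in percentile_levels
    ],
    dtype=np.float64,
)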
72 changes: 59 additions & 13 deletions sklearn/utils/stats.py
@@ -7,8 +7,13 @@
)


def _weighted_percentile(array, sample_weight, percentile_rank=50, xp=None):
"""Compute the weighted percentile with method 'inverted_cdf'.
def _weighted_percentile(
array, sample_weight, percentile_rank=50, average=False, xp=None
):
"""Compute the weighted percentile.

Uses the 'inverted_cdf' method when `average=False` (default) and
'averaged_inverted_cdf' when `average=True`.

When the percentile lies between two data points of `array`, the function returns
the lower value.
@@ -38,6 +43,14 @@ def _weighted_percentile(array, sample_weight, percentile_rank=50, xp=None):
The probability level of the percentile to compute, in percent. Must be between
0 and 100.

average : bool, default=False
If `True`, uses the "averaged_inverted_cdf" quantile method, otherwise
defaults to "inverted_cdf". "averaged_inverted_cdf" is symmetrical with
unit `sample_weight`, such that the total `sample_weight` at or below
`_weighted_percentile(percentile_rank)` equals the total `sample_weight` at or
above `_weighted_percentile(100 - percentile_rank)`.
This symmetry is not guaranteed with non-unit weights.

xp : array_namespace, default=None
The standard-compatible namespace for `array`. Default: infer.

@@ -101,22 +114,55 @@ def _weighted_percentile(array, sample_weight, percentile_rank=50, xp=None):
for feature_idx in range(weight_cdf.shape[0])
],
)
# In rare cases, `percentile_indices` equals to `sorted_idx.shape[0]`
# `percentile_indices` may be equal to `sorted_idx.shape[0]` due to floating
# point error (see #11813)
max_idx = sorted_idx.shape[0] - 1
percentile_indices = xp.clip(percentile_indices, 0, max_idx)

col_indices = xp.arange(array.shape[1], device=device)
percentile_in_sorted = sorted_idx[percentile_indices, col_indices]

result = array[percentile_in_sorted, col_indices]
if average:
# From Hyndman and Fan (1996), `fraction_above` is `g`
fraction_above = (
weight_cdf[col_indices, percentile_indices] - adjusted_percentile_rank
)
# Alternatively, could use
# `is_exact_percentile = fraction_above <= xp.finfo(floating_dtype).eps`
# but that seems harder to read
is_fraction_above = fraction_above > xp.finfo(floating_dtype).eps
percentile_plus_one_indices = xp.clip(percentile_indices + 1, 0, max_idx)
percentile_plus_one_in_sorted = sorted_idx[
percentile_plus_one_indices, col_indices
]
# Handle case when the next index ('plus one') has a sample weight of 0
zero_weight_cols = col_indices[
sample_weight[percentile_plus_one_in_sorted, col_indices] == 0
]
for col_idx in zero_weight_cols:
cdf_val = weight_cdf[col_idx, percentile_indices[col_idx]]
# Search for the next index where `weight_cdf` is greater
next_index = xp.searchsorted(
weight_cdf[col_idx, ...], cdf_val, side="right"
)
# Handle case where there are trailing zero-weight samples
# and `percentile_indices` is already the max index
if next_index >= max_idx:
# use original `percentile_indices` again
next_index = percentile_indices[col_idx]

percentile_plus_one_in_sorted[col_idx] = sorted_idx[next_index, col_idx]

result = xp.where(
is_fraction_above,
array[percentile_in_sorted, col_indices],
@lucyleeow (Member, Author) commented on Jul 17, 2025:
I initially thought this should be `percentile_plus_one_in_sorted`, since in the paper $\gamma = 1$ when $g > 0$. But `searchsorted` defaults to `side='left'`, which puts the equality on the right: the paper defines $j \le pn < j+1$, whereas `searchsorted` effectively gives $i - 1 < pn \le i$. This means that when $pn$ is strictly greater than the lower bound, `searchsorted`'s $i$ already equals the paper's $j + 1$.

When the quantile exactly matches an index, `searchsorted`'s $i$ equals the paper's $j$ (the equality is on opposite sides in the paper vs `searchsorted`).
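A tiny illustration of the convention gap described in this comment (plain NumPy; with unit weights the weighted CDF is just 1, 2, 3, ..., and indices here are 0-based):

import numpy as np

cdf = np.array([1.0, 2.0, 3.0, 4.0])  # toy weighted CDF, unit weights

# Paper: j <= pn < j+1 (equality binds on the left).
# searchsorted(side='left'): smallest i with pn <= cdf[i], i.e. i-1 < pn <= i.
print(np.searchsorted(cdf, 2.0, side="left"))  # 1 -> exact hit lands on the paper's j
print(np.searchsorted(cdf, 2.5, side="left"))  # 2 -> pn above the bound lands on j+1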

(
array[percentile_in_sorted, col_indices]
+ array[percentile_plus_one_in_sorted, col_indices]
)
/ 2,
)
else:
result = array[percentile_in_sorted, col_indices]

return result[0] if n_dim == 1 else result


# TODO: refactor to do the symmetrisation inside _weighted_percentile to avoid
# sorting the input array twice.
def _averaged_weighted_percentile(array, sample_weight, percentile_rank=50, xp=None):
return (
_weighted_percentile(array, sample_weight, percentile_rank, xp=xp)
- _weighted_percentile(-array, sample_weight, 100 - percentile_rank, xp=xp)
) / 2
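For reference, a self-contained 1-D NumPy sketch (an illustrative reimplementation, not the array-API code above) of the identity the removed helper relied on, and of why it cost two sorts:

import numpy as np

def inverted_cdf(a, w, p):
    # Weighted percentile, 'inverted_cdf' method: the lowest value whose
    # cumulative weight reaches p percent of the total weight.
    order = np.argsort(a)  # the sort this PR wants to pay for only once
    a_sorted, w_sorted = a[order], w[order]
    cdf = np.cumsum(w_sorted)
    idx = np.searchsorted(cdf, p / 100 * cdf[-1], side="left")
    return a_sorted[min(idx, a_sorted.size - 1)]

def averaged_two_sorts(a, w, p):
    # The removed helper's symmetrisation: a second call on `-a`, hence a
    # second sort of the same data. The PR folds the averaging into one
    # sorted pass via `average=True` instead.
    return (inverted_cdf(a, w, p) - inverted_cdf(-a, w, 100 - p)) / 2

x = np.array([1.0, 2.0, 3.0, 4.0])
w = np.ones_like(x)
print(averaged_two_sorts(x, w, 50))                          # 2.5
print(np.percentile(x, 50, method="averaged_inverted_cdf"))  # 2.5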