diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 3e0148345ffa1..618a64e7d2848 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -28,7 +28,7 @@ _xlogy as xlogy, ) from ..utils._param_validation import Interval, StrOptions, validate_params -from ..utils.stats import _averaged_weighted_percentile, _weighted_percentile +from ..utils.stats import _weighted_percentile from ..utils.validation import ( _check_sample_weight, _num_samples, @@ -923,8 +923,8 @@ def median_absolute_error( if sample_weight is None: output_errors = _median(xp.abs(y_pred - y_true), axis=0) else: - output_errors = _averaged_weighted_percentile( - xp.abs(y_pred - y_true), sample_weight=sample_weight + output_errors = _weighted_percentile( + xp.abs(y_pred - y_true), sample_weight=sample_weight, average=True ) if isinstance(multioutput, str): if multioutput == "raw_values": diff --git a/sklearn/preprocessing/_discretization.py b/sklearn/preprocessing/_discretization.py index ef5081080bda1..2513ceee7938e 100644 --- a/sklearn/preprocessing/_discretization.py +++ b/sklearn/preprocessing/_discretization.py @@ -10,7 +10,7 @@ from ..base import BaseEstimator, TransformerMixin, _fit_context from ..utils import resample from ..utils._param_validation import Interval, Options, StrOptions -from ..utils.stats import _averaged_weighted_percentile, _weighted_percentile +from ..utils.stats import _weighted_percentile from ..utils.validation import ( _check_feature_names_in, _check_sample_weight, @@ -357,17 +357,20 @@ def fit(self, X, y=None, sample_weight=None): dtype=np.float64, ) else: - # TODO: make _weighted_percentile and - # _averaged_weighted_percentile accept an array of + # TODO: make _weighted_percentile accept an array of # quantiles instead of calling it multiple times and # sorting the column multiple times as a result. - percentile_func = { - "inverted_cdf": _weighted_percentile, - "averaged_inverted_cdf": _averaged_weighted_percentile, - }[quantile_method] + average = ( + True if quantile_method == "averaged_inverted_cdf" else False + ) bin_edges[jj] = np.asarray( [ - percentile_func(column, sample_weight, percentile_rank=p) + _weighted_percentile( + column, + sample_weight, + percentile_rank=p, + average=average, + ) for p in percentile_levels ], dtype=np.float64, diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 66179e5ea3aba..34a76f18a7514 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -7,8 +7,13 @@ ) -def _weighted_percentile(array, sample_weight, percentile_rank=50, xp=None): - """Compute the weighted percentile with method 'inverted_cdf'. +def _weighted_percentile( + array, sample_weight, percentile_rank=50, average=False, xp=None +): + """Compute the weighted percentile. + + Uses 'inverted_cdf' method when `average=False` (default) and + 'averaged_inverted_cdf' when `average=True`. When the percentile lies between two data points of `array`, the function returns the lower value. @@ -38,6 +43,14 @@ def _weighted_percentile(array, sample_weight, percentile_rank=50, xp=None): The probability level of the percentile to compute, in percent. Must be between 0 and 100. + average : bool, default=False + If `True`, uses the "averaged_inverted_cdf" quantile method, otherwise + defaults to "inverted_cdf". "averaged_inverted_cdf" is symmetrical with + unit `sample_weight`, such that the total of `sample_weight` below or equal to + `_weighted_percentile(percentile_rank)` is the same as the total of + `sample_weight` above or equal to `_weighted_percentile(100-percentile_rank). + This symmetry is not guaranteed with non-unit weights. + xp : array_namespace, default=None The standard-compatible namespace for `array`. Default: infer. @@ -101,22 +114,55 @@ def _weighted_percentile(array, sample_weight, percentile_rank=50, xp=None): for feature_idx in range(weight_cdf.shape[0]) ], ) - # In rare cases, `percentile_indices` equals to `sorted_idx.shape[0]` + # `percentile_indices` may be equal to `sorted_idx.shape[0]` due to floating + # point error (see #11813) max_idx = sorted_idx.shape[0] - 1 percentile_indices = xp.clip(percentile_indices, 0, max_idx) col_indices = xp.arange(array.shape[1], device=device) percentile_in_sorted = sorted_idx[percentile_indices, col_indices] - result = array[percentile_in_sorted, col_indices] + if average: + # From Hyndman and Fan (1996), `fraction_above` is `g` + fraction_above = ( + weight_cdf[col_indices, percentile_indices] - adjusted_percentile_rank + ) + # Alternatively, could use + # `is_exact_percentile = fraction_above <= xp.finfo(floating_dtype).eps` + # but that seems harder to read + is_fraction_above = fraction_above > xp.finfo(floating_dtype).eps + percentile_plus_one_indices = xp.clip(percentile_indices + 1, 0, max_idx) + percentile_plus_one_in_sorted = sorted_idx[ + percentile_plus_one_indices, col_indices + ] + # Handle case when when next index ('plus one') has sample weight of 0 + zero_weight_cols = col_indices[ + sample_weight[percentile_plus_one_in_sorted, col_indices] == 0 + ] + for col_idx in zero_weight_cols: + cdf_val = weight_cdf[col_idx, percentile_indices[col_idx]] + # Search for next index where `weighted_cdf` is greater + next_index = xp.searchsorted( + weight_cdf[col_idx, ...], cdf_val, side="right" + ) + # Handle case where there are trailing 0 sample weight samples + # and `percentile_indices` is already max index + if next_index >= max_idx: + # use original `percentile_indices` again + next_index = percentile_indices[col_idx] + + percentile_plus_one_in_sorted[col_idx] = sorted_idx[next_index, col_idx] + + result = xp.where( + is_fraction_above, + array[percentile_in_sorted, col_indices], + ( + array[percentile_in_sorted, col_indices] + + array[percentile_plus_one_in_sorted, col_indices] + ) + / 2, + ) + else: + result = array[percentile_in_sorted, col_indices] return result[0] if n_dim == 1 else result - - -# TODO: refactor to do the symmetrisation inside _weighted_percentile to avoid -# sorting the input array twice. -def _averaged_weighted_percentile(array, sample_weight, percentile_rank=50, xp=None): - return ( - _weighted_percentile(array, sample_weight, percentile_rank, xp=xp) - - _weighted_percentile(-array, sample_weight, 100 - percentile_rank, xp=xp) - ) / 2 diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py index 1c979425f12f8..ce16bc7eef27e 100644 --- a/sklearn/utils/tests/test_stats.py +++ b/sklearn/utils/tests/test_stats.py @@ -12,121 +12,148 @@ from sklearn.utils._array_api import device as array_device from sklearn.utils.estimator_checks import _array_api_for_tests from sklearn.utils.fixes import np_version, parse_version -from sklearn.utils.stats import _averaged_weighted_percentile, _weighted_percentile +from sklearn.utils.stats import _weighted_percentile -def test_averaged_weighted_median(): - y = np.array([0, 1, 2, 3, 4, 5]) - sw = np.array([1, 1, 1, 1, 1, 1]) +@pytest.mark.parametrize("average", [True, False]) +@pytest.mark.parametrize("size", [10, 15]) +def test_weighted_percentile_matches_median(size, average): + """Ensure `_weighted_percentile` matches `median` when expected. - score = _averaged_weighted_percentile(y, sw, 50) + With unit `sample_weight`, `_weighted_percentile` should match median except + when `average=False` and the number of samples is odd. + When number of samples is odd, `_weighted_percentile(average=False)` always falls + on a single observation (not between 2 values, in which case the lower value would + be taken) and is thus equal to `np.median`. + For an even number of samples, `median` gives the average between the 2 middle + samples, `_weighted_percentile(average=False)` gives the higher (right) sample. + """ + y = np.arange(size) + sample_weight = np.ones_like(y) - assert score == np.median(y) + score = _weighted_percentile(y, sample_weight, 50, average=average) + # `_weighted_percentile(average=False)` does not match `median` when n is even + if size == 10 and average is False: + assert score != np.median(y) + else: + assert score == np.median(y) -def test_averaged_weighted_percentile(global_random_seed): - rng = np.random.RandomState(global_random_seed) - y = rng.randint(20, size=10) - sw = np.ones(10) +# test 2D? +@pytest.mark.parametrize("average", [True, False]) +@pytest.mark.parametrize("percentile_rank", [20, 35, 61]) +@pytest.mark.parametrize("size", [10, 15]) +def test_weighted_percentile_matches_numpy( + global_random_seed, size, percentile_rank, average +): + """Check `_weighted_percentile` with unit weights is correct. - score = _averaged_weighted_percentile(y, sw, 20) + `average=True` results should be the same as `np.percentile`'s + 'averaged_inverted_cdf'. + `average=False` results should be the same as `np.percentile`'s + 'inverted_cdf'. + Note `np.percentile` is the same as `np.quantile` except `q` is in range [0, 100]. - assert score == np.percentile(y, 20, method="averaged_inverted_cdf") + We parametrize through different `percentile_rank` and `size` to + ensure we get cases where `g=0` and `g>0` (see Hyndman and Fan 1996 for details). + """ + rng = np.random.RandomState(global_random_seed) + y = rng.randint(20, size=size) + sw = np.ones_like(y) + score = _weighted_percentile(y, sw, percentile_rank, average=average) -def test_averaged_and_weighted_percentile(): - y = np.array([0, 1, 2]) - sw = np.array([5, 1, 5]) - q = 50 + if average: + method = "averaged_inverted_cdf" + else: + method = "inverted_cdf" - score_averaged = _averaged_weighted_percentile(y, sw, q) - score = _weighted_percentile(y, sw, q) + assert score == np.percentile(y, percentile_rank, method=method) - assert score_averaged == score +@pytest.mark.parametrize("percentile_rank", [50, 100]) +def test_weighted_percentile_plus_one_clip_max(percentile_rank): + """Check `j+1` index is clipped to max, when `average=True`. -def test_weighted_percentile(): - """Check `weighted_percentile` on artificial data with obvious median.""" - y = np.empty(102, dtype=np.float64) - y[:50] = 0 - y[-51:] = 2 - y[-1] = 100000 - y[50] = 1 - sw = np.ones(102, dtype=np.float64) - sw[-1] = 0.0 - value = _weighted_percentile(y, sw, 50) - assert approx(value) == 1 + `percentile_plus_one_indices` can exceed max index when `percentile_indices` + is already at max index. + Note that when `g` (Hyndman and Fan) / `fraction_above` greater than 0, + `j+1` (Hyndman and Fan) / `percentile_plus_one_indices` is calculated but + never used (so it does not matter what this value is). + When `g=0` and `percentile_indices` is at max index, we perfectly at 100 + and take the average of 2x the max index. + """ + # Note for both spercentile_rank`s`,`percentile_indices` is already at max index + y = np.array([[0, 0], [1, 1]]) + sw = np.array([[0.1, 0.1], [2, 2]]) + score = _weighted_percentile(y, sw, percentile_rank) + for idx in range(2): + assert score[idx] == approx(1.0) def test_weighted_percentile_equal(): - """Check `weighted_percentile` with all weights equal to 1.""" - y = np.empty(102, dtype=np.float64) - y.fill(0.0) + """Check `weighted_percentile` with unit weights and all 0 values in `array`.""" + y = np.zeros(102, dtype=np.float64) sw = np.ones(102, dtype=np.float64) score = _weighted_percentile(y, sw, 50) assert approx(score) == 0 -def test_weighted_percentile_zero_weight(): - """Check `weighted_percentile` with all weights equal to 0.""" - y = np.empty(102, dtype=np.float64) - y.fill(1.0) - sw = np.ones(102, dtype=np.float64) - sw.fill(0.0) +def test_weighted_percentile_all_zero_weights(): + """Check `weighted_percentile` with all weights equal to 0 returns last index.""" + y = np.arange(10) + sw = np.zeros(10) value = _weighted_percentile(y, sw, 50) - assert approx(value) == 1.0 + assert approx(value) == 9.0 -def test_weighted_percentile_zero_weight_zero_percentile(): +@pytest.mark.parametrize("average", [True, False]) +@pytest.mark.parametrize("percentile_rank, expected_value", [(0, 2), (50, 3), (100, 5)]) +def test_weighted_percentile_ignores_zero_weight( + average, percentile_rank, expected_value +): """Check `weighted_percentile(percentile_rank=0)` behaves correctly. - Ensures that (leading)zero-weight observations ignored when `percentile_rank=0`. + Check that leading zero-weight observations ignored when `percentile_rank=0`. See #20528 for details. + Check that when `average=True` and the `j+1` ('plus one') index has sample weight + of 0, it is ignored. Also check that trailing zero weight observations are ignored + (e.g., when `percentile_rank=100`). """ - y = np.array([0, 1, 2, 3, 4, 5]) - sw = np.array([0, 0, 1, 1, 1, 0]) - value = _weighted_percentile(y, sw, 0) - assert approx(value) == 2 - - value = _weighted_percentile(y, sw, 50) - assert approx(value) == 3 - - value = _weighted_percentile(y, sw, 100) - assert approx(value) == 4 - - -def test_weighted_median_equal_weights(global_random_seed): - """Checks `_weighted_percentile(percentile_rank=50)` is the same as `np.median`. + y = np.array([0, 1, 2, 3, 4, 5, 6]) + sw = np.array([0, 0, 1, 1, 0, 1, 0]) - `sample_weights` are all 1s and the number of samples is odd. - When number of samples is odd, `_weighted_percentile` always falls on a single - observation (not between 2 values, in which case the lower value would be taken) - and is thus equal to `np.median`. - For an even number of samples, this check will not always hold as (note that - for some other percentile methods it will always hold). See #17370 for details. - """ - rng = np.random.RandomState(global_random_seed) - x = rng.randint(10, size=11) - weights = np.ones(x.shape) - median = np.median(x) - w_median = _weighted_percentile(x, weights) - assert median == approx(w_median) + value = _weighted_percentile( + np.vstack((y, y)).T, np.vstack((sw, sw)).T, percentile_rank, average=average + ) + for idx in range(2): + assert approx(value[idx]) == expected_value -def test_weighted_median_integer_weights(global_random_seed): - # Checks average weighted percentile_rank=0.5 is same as median when manually weight - # data +@pytest.mark.parametrize("average", [True, False]) +@pytest.mark.parametrize("percentile_rank", [20, 35, 61]) +def test_weighted_median_frequency_weights( + global_random_seed, percentile_rank, average +): + """Check integer weights give the same result as repeating values.""" rng = np.random.RandomState(global_random_seed) x = rng.randint(20, size=10) weights = rng.choice(5, size=10) - x_manual = np.repeat(x, weights) - median = np.median(x_manual) - w_median = _averaged_weighted_percentile(x, weights) - assert median == approx(w_median) + x_repeated = np.repeat(x, weights) + percentile_weights = _weighted_percentile( + x, weights, percentile_rank, average=average + ) + percentile_repeated = _weighted_percentile( + x_repeated, np.ones_like(x_repeated), percentile_rank, average=average + ) + assert percentile_weights == approx(percentile_repeated) -def test_weighted_percentile_2d(global_random_seed): + +@pytest.mark.parametrize("average", [True, False]) +def test_weighted_percentile_2d(global_random_seed, average): + """Check `_weighted_percentile` behaviour correct when `array` is 2D.""" # Check for when array 2D and sample_weight 1D rng = np.random.RandomState(global_random_seed) x1 = rng.randint(10, size=10) @@ -135,16 +162,21 @@ def test_weighted_percentile_2d(global_random_seed): x2 = rng.randint(20, size=10) x_2d = np.vstack((x1, x2)).T - w_median = _weighted_percentile(x_2d, w1) - p_axis_0 = [_weighted_percentile(x_2d[:, i], w1) for i in range(x_2d.shape[1])] + w_median = _weighted_percentile(x_2d, w1, average=average) + p_axis_0 = [ + _weighted_percentile(x_2d[:, i], w1, average=average) + for i in range(x_2d.shape[1]) + ] assert_allclose(w_median, p_axis_0) + # Check when array and sample_weight both 2D w2 = rng.choice(5, size=10) w_2d = np.vstack((w1, w2)).T - w_median = _weighted_percentile(x_2d, w_2d) + w_median = _weighted_percentile(x_2d, w_2d, average=average) p_axis_0 = [ - _weighted_percentile(x_2d[:, i], w_2d[:, i]) for i in range(x_2d.shape[1]) + _weighted_percentile(x_2d[:, i], w_2d[:, i], average=average) + for i in range(x_2d.shape[1]) ] assert_allclose(w_median, p_axis_0) @@ -234,12 +266,18 @@ def test_weighted_percentile_array_api_consistency( assert result_xp_np.dtype == np.float64 +@pytest.mark.parametrize("average", [True, False]) @pytest.mark.parametrize("sample_weight_ndim", [1, 2]) -def test_weighted_percentile_nan_filtered(sample_weight_ndim, global_random_seed): - """Test that calling _weighted_percentile on an array with nan values returns - the same results as calling _weighted_percentile on a filtered version of the data. +def test_weighted_percentile_nan_filtered( + global_random_seed, sample_weight_ndim, average +): + """Test `_weighted_percentile` ignores NaNs. + + Calling `_weighted_percentile` on an array with nan values returns the same + results as calling `_weighted_percentile` on a filtered version of the data. We test both with sample_weight of the same shape as the data and with - one-dimensional sample_weight.""" + one-dimensional sample_weight. + """ rng = np.random.RandomState(global_random_seed) array_with_nans = rng.rand(100, 10) @@ -252,7 +290,7 @@ def test_weighted_percentile_nan_filtered(sample_weight_ndim, global_random_seed sample_weight = rng.randint(1, 6, size=(100,)) # Find the weighted percentile on the array with nans: - results = _weighted_percentile(array_with_nans, sample_weight, 30) + results = _weighted_percentile(array_with_nans, sample_weight, 30, average=average) # Find the weighted percentile on the filtered array: filtered_array = [ @@ -269,7 +307,9 @@ def test_weighted_percentile_nan_filtered(sample_weight_ndim, global_random_seed expected_results = np.array( [ - _weighted_percentile(filtered_array[col], filtered_weights[col], 30) + _weighted_percentile( + filtered_array[col], filtered_weights[col], 30, average=average + ) for col in range(array_with_nans.shape[1]) ] ) @@ -307,8 +347,10 @@ def test_weighted_percentile_all_nan_column(): ) @pytest.mark.parametrize("percentile", [66, 10, 50]) def test_weighted_percentile_like_numpy_quantile(percentile, global_random_seed): - """Check that _weighted_percentile delivers equivalent results as np.quantile - with weights.""" + """Check `_weighted_percentile` equivalent to `np.quantile` with weights. + + Note currently only "inverted_cdf" method accepts weights. + """ rng = np.random.RandomState(global_random_seed) array = rng.rand(10, 100) @@ -330,9 +372,10 @@ def test_weighted_percentile_like_numpy_quantile(percentile, global_random_seed) ) @pytest.mark.parametrize("percentile", [66, 10, 50]) def test_weighted_percentile_like_numpy_nanquantile(percentile, global_random_seed): - """Check that _weighted_percentile delivers equivalent results as np.nanquantile - with weights.""" + """Check `_weighted_percentile` equivalent to `np.nanquantile` with weights. + Note currently only "inverted_cdf" method accepts weights. + """ rng = np.random.RandomState(global_random_seed) array_with_nans = rng.rand(10, 100) array_with_nans[rng.rand(*array_with_nans.shape) < 0.5] = np.nan