Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[MRG] ENH: Add sample_weight to median_absolute_error #6217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
3 changes: 3 additions & 0 deletions doc/developers/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ Efficient Linear Algebra & Array Operations
- :func:`shuffle`: Shuffle arrays or sparse matrices in a consistent way.
Used in ``sklearn.cluster.k_means``.

- :func:`extmath.weighted_median`: Compute the weighted median of an array,
weighting each value by its sample weight.


Efficient Random Sampling
=========================
Expand Down
15 changes: 13 additions & 2 deletions sklearn/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from ..utils.validation import check_array, check_consistent_length
from ..utils.validation import column_or_1d
from ..utils.extmath import weighted_median
from ..externals.six import string_types

import warnings
Expand Down Expand Up @@ -241,7 +242,7 @@ def mean_squared_error(y_true, y_pred,
return np.average(output_errors, weights=multioutput)


def median_absolute_error(y_true, y_pred):
def median_absolute_error(y_true, y_pred, sample_weight=None):
"""Median absolute error regression loss

Read more in the :ref:`User Guide <median_absolute_error>`.
Expand All @@ -254,6 +255,9 @@ def median_absolute_error(y_true, y_pred):
y_pred : array-like of shape = (n_samples)
Estimated target values.

sample_weight : array-like of shape = (n_samples), optional
Sample weights.

Returns
-------
loss : float
Expand All @@ -272,7 +276,14 @@ def median_absolute_error(y_true, y_pred):
'uniform_average')
if y_type == 'continuous-multioutput':
raise ValueError("Multioutput not supported in median_absolute_error")
return np.median(np.abs(y_pred - y_true))
y_pred = y_pred.ravel()
y_true = y_true.ravel()
if sample_weight is None:
sample_weight = np.ones_like(y_true)
else:
check_consistent_length(y_pred, sample_weight)
sample_weight = np.asarray(sample_weight)
return weighted_median(np.abs(y_pred - y_true), sample_weight)


def explained_variance_score(y_true, y_pred,
Expand Down
11 changes: 6 additions & 5 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,6 @@
# matrix instead of a number. Testing of
# confusion_matrix with sample_weight is in
# test_classification.py
"median_absolute_error",
]


Expand Down Expand Up @@ -956,10 +955,12 @@ def check_sample_weight_invariance(name, metric, y1, y2):

# check that the weighted and unweighted scores are unequal
weighted_score = metric(y1, y2, sample_weight=sample_weight)
assert_not_equal(
unweighted_score, weighted_score,
msg="Unweighted and weighted scores are unexpectedly "
"equal (%f) for %s" % (weighted_score, name))
if name != "median_absolute_error":
# unweighted and weighted give same value for this metric. see #6217
assert_not_equal(
unweighted_score, weighted_score,
msg="Unweighted and weighted scores are unexpectedly "
"equal (%f) for %s" % (weighted_score, name))

# check that sample_weight can be a list
weighted_score_list = metric(y1, y2,
Expand Down
21 changes: 21 additions & 0 deletions sklearn/metrics/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_not_equal
from sklearn.utils.testing import assert_almost_equal
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_array_almost_equal
Expand Down Expand Up @@ -141,3 +142,23 @@ def test_regression_custom_weights():
assert_almost_equal(maew, 0.475, decimal=3)
assert_almost_equal(rw, 0.94, decimal=2)
assert_almost_equal(evsw, 0.94, decimal=2)


def test_median_absolute_error_weights():
    """Check how ``median_absolute_error`` responds to ``sample_weight``.

    Unit weights must reproduce the unweighted score, while non-uniform
    weights must change the score for this particular data set.
    """
    y_tr = [3, -0.5, 2, 7]
    y_pr = [2.5, 0.0, 2, 8]
    sample_weight = [2, 3, 1, 4]

    # check that unit weights gives the same score as no weight
    unweighted_score = median_absolute_error(y_tr, y_pr, sample_weight=None)
    weighted_score = median_absolute_error(y_tr, y_pr,
                                           sample_weight=np.ones(len(y_tr)))
    # The two literals below are concatenated implicitly, so the first one
    # must end with a space to keep the message readable.
    assert_almost_equal(unweighted_score, weighted_score,
                        err_msg="For median_absolute_error sample_weight=None "
                                "is not equivalent to sample_weight=ones")

    # check that the weighted and unweighted scores are unequal
    weighted_score = median_absolute_error(y_tr, y_pr,
                                           sample_weight=sample_weight)
    assert_not_equal(unweighted_score, weighted_score,
                     msg="Unweighted and weighted scores are unexpectedly "
                         "equal (%f) for median_absolute_error"
                         % weighted_score)
9 changes: 5 additions & 4 deletions sklearn/metrics/tests/test_score_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,10 +384,11 @@ def test_scorer_sample_weight():
sample_weight=sample_weight)
ignored = scorer(estimator[name], X_test[10:], target[10:])
unweighted = scorer(estimator[name], X_test, target)
assert_not_equal(weighted, unweighted,
msg="scorer {0} behaves identically when "
"called with sample weights: {1} vs "
"{2}".format(name, weighted, unweighted))
if "median_absolute_error" not in name:
assert_not_equal(weighted, unweighted,
msg="scorer {0} behaves identically when "
"called with sample weights: {1} vs "
"{2}".format(name, weighted, unweighted))
assert_almost_equal(weighted, ignored,
err_msg="scorer {0} behaves differently when "
"ignoring samples and setting sample_weight to"
Expand Down
39 changes: 39 additions & 0 deletions sklearn/utils/extmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,3 +862,42 @@ def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
raise RuntimeError('cumsum was found to be unstable: '
'its last element does not correspond to sum')
return out


def weighted_median(array, sample_weight):
    """Compute the weighted median of ``array`` using ``sample_weight``.

    The values are sorted, each value is placed at the midpoint of its own
    weight interval on the normalised cumulative-weight axis, and the value
    at position 0.5 on that axis is obtained by linear interpolation.

    Parameters
    ----------
    array : array_like
        Values of which to find the weighted median. Inputs with more than
        one dimension are flattened.
    sample_weight : array_like
        Weight of each value. Must contain as many elements as ``array``.

    Returns
    -------
    weighted_median : float
        Weighted median of the array.

    Raises
    ------
    ValueError
        If ``array`` and ``sample_weight`` do not contain the same number
        of elements.

    Examples
    --------
    >>> from sklearn.utils.extmath import weighted_median
    >>> import numpy as np
    >>> weighted_median(np.array([1,2,3,4]),np.array([1,1,1,1]))
    2.5
    >>> weighted_median(np.array([1,2,3]),np.array([1,1,1]))
    2.0

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Percentile

    """
    # Accept any array_like (lists included) and flatten, so that the
    # argsort indices below address individual elements.
    array = np.ravel(array)
    sample_weight = np.ravel(sample_weight)
    if array.shape != sample_weight.shape:
        raise ValueError("array and sample_weight must have the same number "
                         "of elements, got %d and %d"
                         % (array.size, sample_weight.size))

    sorted_idx = np.argsort(array)
    sorted_array = array[sorted_idx]
    sorted_sample_weight = sample_weight[sorted_idx]

    # Position of each sorted value: cumulative weight up to and including
    # it, minus half of its own weight, normalised by the total weight.
    weight_cdf = sorted_sample_weight.cumsum()
    weighted_percentile = weight_cdf - sorted_sample_weight / 2.0
    weighted_percentile = weighted_percentile / weight_cdf[-1]

    return np.interp(0.5, weighted_percentile, sorted_array)
1 change: 0 additions & 1 deletion sklearn/utils/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def _rankdata(a, method="average"):
def _weighted_percentile(array, sample_weight, percentile=50):
"""Compute the weighted ``percentile`` of ``array`` with ``sample_weight``. """
sorted_idx = np.argsort(array)

# Find index of median prediction for each sample
weight_cdf = sample_weight[sorted_idx].cumsum()
percentile_idx = np.searchsorted(
Expand Down
26 changes: 25 additions & 1 deletion sklearn/utils/tests/test_extmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from sklearn.utils.extmath import norm, squared_norm
from sklearn.utils.extmath import randomized_svd
from sklearn.utils.extmath import row_norms
from sklearn.utils.extmath import weighted_mode
from sklearn.utils.extmath import weighted_mode, weighted_median
from sklearn.utils.extmath import cartesian
from sklearn.utils.extmath import log_logistic
from sklearn.utils.extmath import fast_dot, _fast_dot
Expand Down Expand Up @@ -658,3 +658,27 @@ def test_stable_cumsum():
'cumsum was found to be unstable: its last element '
'does not correspond to sum',
stable_cumsum, r, rtol=0, atol=0)


def test_weighted_median():
    """With unit weights the weighted median must equal ``np.median``."""
    random_state = np.random.RandomState(0)
    values = random_state.randint(10, size=(10,))
    unit_weights = np.ones(values.shape)

    expected = np.median(values)
    actual = weighted_median(values, unit_weights)
    assert_almost_equal(expected, actual)


def test_weighted_median_equal_split():
    """Weights strictly below and strictly above the median must balance."""
    random_state = np.random.RandomState(0)
    uniform_pvals = [1/5.]*5

    # Draw the two halves of the weight vector first so that the random
    # stream is consumed in the same order as before.
    left_half = random_state.multinomial(20, uniform_pvals, size=1)[0]
    right_half = random_state.multinomial(20, uniform_pvals, size=1)[0]

    # Ten distinct values taken from a shuffled 0..19 range, then sorted.
    values = np.asarray(range(20))
    random_state.shuffle(values)
    values = values[10:]
    values.sort()

    all_weights = np.hstack((left_half, right_half))
    median_value = weighted_median(values, all_weights)

    weight_below = np.sum(all_weights[values < median_value])
    weight_above = np.sum(all_weights[values > median_value])
    assert_equal(weight_below, weight_above)