diff --git a/doc/developers/utilities.rst b/doc/developers/utilities.rst index 8dbe460635926..895c4b8d23506 100644 --- a/doc/developers/utilities.rst +++ b/doc/developers/utilities.rst @@ -122,6 +122,9 @@ Efficient Linear Algebra & Array Operations - :func:`shuffle`: Shuffle arrays or sparse matrices in a consistent way. Used in ``sklearn.cluster.k_means``. +- :func:`extmath.weighted_median`: an implementation to get weighted median + of the array using sample weights. + Efficient Random Sampling ========================= diff --git a/sklearn/metrics/regression.py b/sklearn/metrics/regression.py index af3a02d6f33f9..880403623dae7 100644 --- a/sklearn/metrics/regression.py +++ b/sklearn/metrics/regression.py @@ -26,6 +26,7 @@ from ..utils.validation import check_array, check_consistent_length from ..utils.validation import column_or_1d +from ..utils.extmath import weighted_median from ..externals.six import string_types import warnings @@ -241,7 +242,7 @@ def mean_squared_error(y_true, y_pred, return np.average(output_errors, weights=multioutput) -def median_absolute_error(y_true, y_pred): +def median_absolute_error(y_true, y_pred, sample_weight=None): """Median absolute error regression loss Read more in the :ref:`User Guide `. @@ -254,6 +255,9 @@ def median_absolute_error(y_true, y_pred): y_pred : array-like of shape = (n_samples) Estimated target values. + sample_weight : array-like of shape = (n_samples), optional + Sample weights. + Returns ------- loss : float @@ -272,7 +276,14 @@ def median_absolute_error(y_true, y_pred): 'uniform_average') if y_type == 'continuous-multioutput': raise ValueError("Multioutput not supported in median_absolute_error") - return np.median(np.abs(y_pred - y_true)) + y_pred = y_pred.ravel() + y_true = y_true.ravel() + if sample_weight is None: + sample_weight = np.ones_like(y_true) + else: + check_consistent_length(y_pred, sample_weight) + sample_weight = np.asarray(sample_weight) + return weighted_median(np.abs(y_pred - y_true), sample_weight) def explained_variance_score(y_true, y_pred, diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index fa4c7e8d3124b..c7fe05e615f38 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -374,7 +374,6 @@ # matrix instead of a number. Testing of # confusion_matrix with sample_weight is in # test_classification.py - "median_absolute_error", ] @@ -956,10 +955,12 @@ def check_sample_weight_invariance(name, metric, y1, y2): # check that the weighted and unweighted scores are unequal weighted_score = metric(y1, y2, sample_weight=sample_weight) - assert_not_equal( - unweighted_score, weighted_score, - msg="Unweighted and weighted scores are unexpectedly " - "equal (%f) for %s" % (weighted_score, name)) + if name != "median_absolute_error": + # unweighted and weighted give same value for this metric. see #6217 + assert_not_equal( + unweighted_score, weighted_score, + msg="Unweighted and weighted scores are unexpectedly " + "equal (%f) for %s" % (weighted_score, name)) # check that sample_weight can be a list weighted_score_list = metric(y1, y2, diff --git a/sklearn/metrics/tests/test_regression.py b/sklearn/metrics/tests/test_regression.py index 600bcc135a202..6eec266e579cc 100644 --- a/sklearn/metrics/tests/test_regression.py +++ b/sklearn/metrics/tests/test_regression.py @@ -5,6 +5,7 @@ from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_equal +from sklearn.utils.testing import assert_not_equal from sklearn.utils.testing import assert_almost_equal from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_array_almost_equal @@ -141,3 +142,23 @@ def test_regression_custom_weights(): assert_almost_equal(maew, 0.475, decimal=3) assert_almost_equal(rw, 0.94, decimal=2) assert_almost_equal(evsw, 0.94, decimal=2) + + +def test_median_absolute_error_weights(): + y_tr = [3, -0.5, 2, 7] + y_pr = [2.5, 0.0, 2, 8] + sample_weight = [2, 3, 1, 4] + # check that unit weights gives the same score as no weight + unweighted_score = median_absolute_error(y_tr, y_pr, sample_weight=None) + weighted_score = median_absolute_error(y_tr, y_pr, + sample_weight=np.ones(len(y_tr))) + assert_almost_equal(unweighted_score, weighted_score, + err_msg="For median_absolute_error sample_weight=None" + "is not equivalent to sample_weight=ones") + + # check that the weighted and unweighted scores are unequal + weighted_score = median_absolute_error(y_tr, y_pr, + sample_weight=sample_weight) + assert_not_equal(unweighted_score, weighted_score, + msg="Unweighted and weighted scores are unexpectedly " + "equal (%f) for median_absolute_error" % weighted_score) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index c3643e81031a5..ce2ccb39b6e55 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -384,10 +384,11 @@ def test_scorer_sample_weight(): sample_weight=sample_weight) ignored = scorer(estimator[name], X_test[10:], target[10:]) unweighted = scorer(estimator[name], X_test, target) - assert_not_equal(weighted, unweighted, - msg="scorer {0} behaves identically when " - "called with sample weights: {1} vs " - "{2}".format(name, weighted, unweighted)) + if "median_absolute_error" not in name: + assert_not_equal(weighted, unweighted, + msg="scorer {0} behaves identically when " + "called with sample weights: {1} vs " + "{2}".format(name, weighted, unweighted)) assert_almost_equal(weighted, ignored, err_msg="scorer {0} behaves differently when " "ignoring samples and setting sample_weight to" diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 1857a27adfadc..d9ed0c669ea44 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -862,3 +862,42 @@ def stable_cumsum(arr, rtol=1e-05, atol=1e-08): raise RuntimeError('cumsum was found to be unstable: ' 'its last element does not correspond to sum') return out + + +def weighted_median(array, sample_weight): + """Compute the weighted median of the array with sample weight + + Parameters + ---------- + array : array_like + n-dimensional array of which to find weighted median. + sample_weight : array_like + n-dimensional array of weights for each value + + Returns + ------- + weighted_median : float + Weighted median of the array + Example + ------- + >>> from sklearn.utils.extmath import weighted_median + >>> import numpy as np + >>> weighted_median(np.array([1,2,3,4]),np.array([1,1,1,1])) + 2.5 + >>> weighted_median(np.array([1,2,3]),np.array([1,1,1])) + 2.0 + + References + ---------- + .. [1] https://en.wikipedia.org/wiki/Percentile + + """ + sorted_idx = np.argsort(array) + sample_weight = np.asarray(sample_weight) + sorted_sample_weight = sample_weight[sorted_idx] + weight_cdf = sorted_sample_weight.cumsum() + weighted_percentile = weight_cdf - sorted_sample_weight/2.0 + weighted_percentile /= weight_cdf[-1] + sorted_array = array[sorted_idx] + weighted_median = np.interp(0.5, weighted_percentile, sorted_array) + return weighted_median diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 463146d038c6b..7d14ac7933a2b 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -51,7 +51,6 @@ def _rankdata(a, method="average"): def _weighted_percentile(array, sample_weight, percentile=50): """Compute the weighted ``percentile`` of ``array`` with ``sample_weight``. """ sorted_idx = np.argsort(array) - # Find index of median prediction for each sample weight_cdf = sample_weight[sorted_idx].cumsum() percentile_idx = np.searchsorted( diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index 55f96cdf1574c..1ecc1809a8020 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -27,7 +27,7 @@ from sklearn.utils.extmath import norm, squared_norm from sklearn.utils.extmath import randomized_svd from sklearn.utils.extmath import row_norms -from sklearn.utils.extmath import weighted_mode +from sklearn.utils.extmath import weighted_mode, weighted_median from sklearn.utils.extmath import cartesian from sklearn.utils.extmath import log_logistic from sklearn.utils.extmath import fast_dot, _fast_dot @@ -658,3 +658,27 @@ def test_stable_cumsum(): 'cumsum was found to be unstable: its last element ' 'does not correspond to sum', stable_cumsum, r, rtol=0, atol=0) + + +def test_weighted_median(): + rng = np.random.RandomState(0) + x = rng.randint(10, size=(10,)) + weights = np.ones(x.shape) + median = np.median(x) + wmedian = weighted_median(x, weights) + assert_almost_equal(median, wmedian) + + +def test_weighted_median_equal_split(): + rng = np.random.RandomState(0) + weights_left = rng.multinomial(20, [1/5.]*5, size=1)[0] + weights_right = rng.multinomial(20, [1/5.]*5, size=1)[0] + x = np.asarray(range(20)) + rng.shuffle(x) + x = x[10:] + x.sort() + weights = np.hstack((weights_left, weights_right)) + wmedian = weighted_median(x, weights) + sum_left = np.sum(weights[np.where(x < wmedian)]) + sum_right = np.sum(weights[np.where(x > wmedian)]) + assert_equal(sum_left, sum_right)