Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit a12e7ad

Browse files
committed
Some more changes and tests
1 parent 5855c6f commit a12e7ad

File tree

4 files changed

+39
-17
lines changed

4 files changed

+39
-17
lines changed

sklearn/metrics/regression.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from ..utils.validation import check_array, check_consistent_length
2828
from ..utils.validation import column_or_1d
29+
from ..utils.extmath import weighted_median
2930
from ..externals.six import string_types
3031

3132
import warnings
@@ -241,16 +242,6 @@ def mean_squared_error(y_true, y_pred,
241242
return np.average(output_errors, weights=multioutput)
242243

243244

244-
def _weighted_percentile(array, sample_weight, percentile=50):
    """Compute a weighted percentile of a 1-D collection of values.

    Each sorted value ``v_k`` with weight ``w_k`` is placed at the position
    ``(S_k - w_k / 2) / S_n``, where ``S_k`` is the cumulative weight up to
    and including ``v_k`` and ``S_n`` is the total weight. The requested
    percentile is then linearly interpolated between those positions.

    Parameters
    ----------
    array : array-like, 1-D
        Values whose percentile is sought.
    sample_weight : array-like, 1-D
        One weight per entry of ``array``, in the same order.
    percentile : float, default 50
        Percentile to compute, expressed on a 0-100 scale.

    Returns
    -------
    float
        The interpolated weighted percentile.
    """
    array = np.asarray(array)
    sorted_idx = np.argsort(array)
    sorted_array = array[sorted_idx]
    # The weights have to travel with their values into sorted order; the
    # half-weight correction below must use the same ordering as the CDF,
    # otherwise the interpolation abscissae are wrong (and may not even be
    # increasing, which np.interp requires).
    sorted_weight = np.asarray(sample_weight)[sorted_idx]
    weight_cdf = sorted_weight.cumsum()
    weighted_percentile = (weight_cdf - sorted_weight / 2.0) / weight_cdf[-1]
    return np.interp(percentile / 100., weighted_percentile, sorted_array)
252-
253-
254245
def median_absolute_error(y_true, y_pred, sample_weight=None):
255246
"""Median absolute error regression loss
256247
@@ -285,15 +276,14 @@ def median_absolute_error(y_true, y_pred, sample_weight=None):
285276
'uniform_average')
286277
if y_type == 'continuous-multioutput':
287278
raise ValueError("Multioutput not supported in median_absolute_error")
288-
289279
if sample_weight is None:
290280
return np.median(np.abs(y_pred - y_true))
291281
else:
292282
check_consistent_length(y_pred, sample_weight)
293283
sample_weight = np.array(sample_weight)
294284
y_pred = y_pred.ravel()
295285
y_true = y_true.ravel()
296-
return _weighted_percentile(np.abs(y_pred - y_true),
286+
return weighted_median(np.abs(y_pred - y_true),
297287
np.asarray(sample_weight))
298288

299289

sklearn/metrics/tests/test_common.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,6 @@
361361
METRICS_WITHOUT_SAMPLE_WEIGHT = [
362362
"cohen_kappa_score",
363363
"confusion_matrix",
364-
"median_absolute_error",
365364
]
366365

367366

@@ -919,10 +918,12 @@ def check_sample_weight_invariance(name, metric, y1, y2):
919918

920919
# check that the weighted and unweighted scores are unequal
921920
weighted_score = metric(y1, y2, sample_weight=sample_weight)
922-
assert_not_equal(
923-
unweighted_score, weighted_score,
924-
msg="Unweighted and weighted scores are unexpectedly "
925-
"equal (%f) for %s" % (weighted_score, name))
921+
if name != "median_absolute_error":
922+
# unweighted and weighted give same value for this metric. see #6217
923+
assert_not_equal(
924+
unweighted_score, weighted_score,
925+
msg="Unweighted and weighted scores are unexpectedly "
926+
"equal (%f) for %s" % (weighted_score, name))
926927

927928
# check that sample_weight can be a list
928929
weighted_score_list = metric(y1, y2,

sklearn/metrics/tests/test_regression.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from sklearn.utils.testing import assert_raises
77
from sklearn.utils.testing import assert_equal
8+
from sklearn.utils.testing import assert_not_equal
89
from sklearn.utils.testing import assert_almost_equal
910
from sklearn.utils.testing import assert_array_equal
1011
from sklearn.utils.testing import assert_array_almost_equal
@@ -142,3 +143,23 @@ def test_regression_custom_weights():
142143
assert_almost_equal(maew, 0.475, decimal=3)
143144
assert_almost_equal(rw, 0.94, decimal=2)
144145
assert_almost_equal(evsw, 0.94, decimal=2)
146+
147+
def test_median_absolute_error_weights():
    """Check how median_absolute_error reacts to ``sample_weight``.

    Two properties are verified: passing all-ones weights matches passing
    no weights at all, and passing non-uniform weights changes the score.
    """
    targets = [3, -0.5, 2, 7]
    predictions = [2.5, 0.0, 2, 8]
    weights = [1, 2, 3, 4]

    # Unit weights must give the same score as omitting the weights.
    unweighted_score = median_absolute_error(targets, predictions,
                                             sample_weight=None)
    unit_weight_score = median_absolute_error(
        targets, predictions, sample_weight=np.ones(shape=len(targets)))
    assert_almost_equal(
        unweighted_score, unit_weight_score,
        err_msg="For median_absolute_error sample_weight=None is not "
                "equivalent to sample_weight=ones")

    # Non-uniform weights are expected to move the score.
    weighted_score = median_absolute_error(targets, predictions,
                                           sample_weight=weights)
    failure_text = ("Unweighted and weighted scores are unexpectedly "
                    "equal (%f) for median_absolute_error" % weighted_score)
    assert_not_equal(unweighted_score, weighted_score, msg=failure_text)

sklearn/utils/extmath.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,3 +828,13 @@ def _deterministic_vector_sign_flip(u):
828828
signs = np.sign(u[range(u.shape[0]), max_abs_rows])
829829
u *= signs[:, np.newaxis]
830830
return u
831+
832+
833+
def weighted_median(array, sample_weight):
    """Compute the weighted median of a 1-D collection of values.

    Each sorted value ``v_k`` with weight ``w_k`` is placed at the position
    ``(S_k - w_k / 2) / S_n``, where ``S_k`` is the cumulative weight up to
    and including ``v_k`` and ``S_n`` is the total weight. The median is the
    linear interpolation of the values at position 0.5.

    Parameters
    ----------
    array : array-like, 1-D
        Values whose median is sought.
    sample_weight : array-like, 1-D
        One weight per entry of ``array``, in the same order.

    Returns
    -------
    float
        The interpolated weighted median.
    """
    # Accept plain sequences as well as ndarrays; fancy indexing below
    # needs an ndarray.
    array = np.asarray(array)
    sorted_idx = np.argsort(array)
    sorted_array = array[sorted_idx]
    # Reorder the weights once and use that ordering for both the CDF and
    # the half-weight correction. Mixing sorted and unsorted weights gives
    # wrong (possibly non-increasing) abscissae for np.interp.
    sorted_weight = np.asarray(sample_weight)[sorted_idx]
    weight_cdf = sorted_weight.cumsum()
    weighted_percentile = (weight_cdf - sorted_weight / 2.0) / weight_cdf[-1]
    return np.interp(0.5, weighted_percentile, sorted_array)

0 commit comments

Comments
 (0)