Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 84eaed3

Browse files
GaelVaroquaux authored and larsmans committed
BUG: highly-degenerate roc curves
Fixes #1658. Make sure that roc_curves start at [0, 0] and end at [1, 1].
1 parent b52269f commit 84eaed3

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

sklearn/metrics/metrics.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -482,11 +482,11 @@ def roc_curve(y_true, y_score, pos_label=None):
482482

483483
if n_pos == 0:
484484
warnings.warn("No positive samples in y_true, "
485-
"true positve value should be meaningless")
485+
"true positive value should be meaningless")
486486
n_pos = np.nan
487487
if n_neg == 0:
488488
warnings.warn("No negative samples in y_true, "
489-
"false positve value should be meaningless")
489+
"false positive value should be meaningless")
490490
n_neg = np.nan
491491

492492
thresholds = np.unique(y_score)
@@ -520,11 +520,28 @@ def roc_curve(y_true, y_score, pos_label=None):
520520
tpr[-1] = (sum_pos + current_pos_count) / n_pos
521521
fpr[-1] = (sum_neg + current_neg_count) / n_neg
522522

523-
# hard decisions, add (0,0)
524-
if fpr.shape[0] == 2:
523+
thresholds = thresholds[::-1]
524+
525+
if not (n_pos is np.nan or n_neg is np.nan):
526+
# add (0,0) and (1, 1)
527+
if not (fpr[0] == 0 and fpr[-1] == 1):
528+
fpr = np.r_[0., fpr, 1.]
529+
tpr = np.r_[0., tpr, 1.]
530+
thresholds = np.r_[thresholds[0] + 1, thresholds,
531+
thresholds[-1] - 1]
532+
elif not fpr[0] == 0:
533+
fpr = np.r_[0., fpr]
534+
tpr = np.r_[0., tpr]
535+
thresholds = np.r_[thresholds[0] + 1, thresholds]
536+
elif not fpr[-1] == 1:
537+
fpr = np.r_[fpr, 1.]
538+
tpr = np.r_[tpr, 1.]
539+
thresholds = np.r_[thresholds, thresholds[-1] - 1]
540+
elif fpr.shape[0] == 2:
541+
# trivial decisions, add (0,0)
525542
fpr = np.array([0.0, fpr[0], fpr[1]])
526543
tpr = np.array([0.0, tpr[0], tpr[1]])
527-
# trivial decisions, add (0,0) and (1,1)
544+
# trivial decisions, add (0,0) and (1,1)
528545
elif fpr.shape[0] == 1:
529546
fpr = np.array([0.0, fpr[0], 1.0])
530547
tpr = np.array([0.0, tpr[0], 1.0])
@@ -535,7 +552,7 @@ def roc_curve(y_true, y_score, pos_label=None):
535552
if n_neg is np.nan:
536553
fpr[0] = np.nan
537554

538-
return fpr, tpr, thresholds[::-1]
555+
return fpr, tpr, thresholds
539556

540557

541558
###############################################################################

sklearn/metrics/tests/test_metrics.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,17 @@ def test_roc_curve():
9393
assert_almost_equal(roc_auc, auc_score(y_true, probas_pred))
9494

9595

96+
def test_roc_curve_end_points():
97+
# Make sure that roc_curve returns a curve start at 0 and ending and
98+
# 1 even in corner cases
99+
rng = np.random.RandomState(0)
100+
y_true = np.array([0] * 50 + [1] * 50)
101+
y_pred = rng.randint(3, size=100)
102+
fpr, tpr, thr = roc_curve(y_true, y_pred)
103+
assert_equal(fpr[0], 0)
104+
assert_equal(fpr[-1], 1)
105+
106+
96107
def test_roc_returns_consistency():
97108
"""Test whether the returned threshold matches up with tpr"""
98109
# make small toy dataset
@@ -101,8 +112,8 @@ def test_roc_returns_consistency():
101112

102113
# use the given thresholds to determine the tpr
103114
tpr_correct = []
104-
for t in range(len(thresholds)):
105-
tp = np.sum((probas_pred >= thresholds[t]) & y_true)
115+
for t in thresholds:
116+
tp = np.sum((probas_pred >= t) & y_true)
106117
p = np.sum(y_true)
107118
tpr_correct.append(1.0 * tp / p)
108119

0 commit comments

Comments
 (0)