@@ -482,11 +482,11 @@ def roc_curve(y_true, y_score, pos_label=None):
482
482
483
483
if n_pos == 0 :
484
484
warnings .warn ("No positive samples in y_true, "
485
- "true positve value should be meaningless" )
485
+ "true positive value should be meaningless" )
486
486
n_pos = np .nan
487
487
if n_neg == 0 :
488
488
warnings .warn ("No negative samples in y_true, "
489
- "false positve value should be meaningless" )
489
+ "false positive value should be meaningless" )
490
490
n_neg = np .nan
491
491
492
492
thresholds = np .unique (y_score )
@@ -520,11 +520,28 @@ def roc_curve(y_true, y_score, pos_label=None):
520
520
tpr [- 1 ] = (sum_pos + current_pos_count ) / n_pos
521
521
fpr [- 1 ] = (sum_neg + current_neg_count ) / n_neg
522
522
523
- # hard decisions, add (0,0)
524
- if fpr .shape [0 ] == 2 :
523
+ thresholds = thresholds [::- 1 ]
524
+
525
+ if not (n_pos is np .nan or n_neg is np .nan ):
526
+ # add (0,0) and (1, 1)
527
+ if not (fpr [0 ] == 0 and fpr [- 1 ] == 1 ):
528
+ fpr = np .r_ [0. , fpr , 1. ]
529
+ tpr = np .r_ [0. , tpr , 1. ]
530
+ thresholds = np .r_ [thresholds [0 ] + 1 , thresholds ,
531
+ thresholds [- 1 ] - 1 ]
532
+ elif not fpr [0 ] == 0 :
533
+ fpr = np .r_ [0. , fpr ]
534
+ tpr = np .r_ [0. , tpr ]
535
+ thresholds = np .r_ [thresholds [0 ] + 1 , thresholds ]
536
+ elif not fpr [- 1 ] == 1 :
537
+ fpr = np .r_ [fpr , 1. ]
538
+ tpr = np .r_ [tpr , 1. ]
539
+ thresholds = np .r_ [thresholds , thresholds [- 1 ] - 1 ]
540
+ elif fpr .shape [0 ] == 2 :
541
+ # trivial decisions, add (0,0)
525
542
fpr = np .array ([0.0 , fpr [0 ], fpr [1 ]])
526
543
tpr = np .array ([0.0 , tpr [0 ], tpr [1 ]])
527
- # trivial decisions, add (0,0) and (1,1)
544
+ # trivial decisions, add (0,0) and (1,1)
528
545
elif fpr .shape [0 ] == 1 :
529
546
fpr = np .array ([0.0 , fpr [0 ], 1.0 ])
530
547
tpr = np .array ([0.0 , tpr [0 ], 1.0 ])
@@ -535,7 +552,7 @@ def roc_curve(y_true, y_score, pos_label=None):
535
552
if n_neg is np .nan :
536
553
fpr [0 ] = np .nan
537
554
538
- return fpr , tpr , thresholds [:: - 1 ]
555
+ return fpr , tpr , thresholds
539
556
540
557
541
558
###############################################################################
0 commit comments