@@ -2624,62 +2624,37 @@ def test_log_loss():
2624
2624
)
2625
2625
loss = log_loss (y_true , y_pred )
2626
2626
loss_true = - np .mean (bernoulli .logpmf (np .array (y_true ) == "yes" , y_pred [:, 1 ]))
2627
- assert_almost_equal (loss , loss_true )
2627
+ assert_allclose (loss , loss_true )
2628
2628
2629
2629
# multiclass case; adapted from http://bit.ly/RJJHWA
2630
2630
y_true = [1 , 0 , 2 ]
2631
2631
y_pred = [[0.2 , 0.7 , 0.1 ], [0.6 , 0.2 , 0.2 ], [0.6 , 0.1 , 0.3 ]]
2632
2632
loss = log_loss (y_true , y_pred , normalize = True )
2633
- assert_almost_equal (loss , 0.6904911 )
2633
+ assert_allclose (loss , 0.6904911 )
2634
2634
2635
2635
# check that we got all the shapes and axes right
2636
2636
# by doubling the length of y_true and y_pred
2637
2637
y_true *= 2
2638
2638
y_pred *= 2
2639
2639
loss = log_loss (y_true , y_pred , normalize = False )
2640
- assert_almost_equal (loss , 0.6904911 * 6 , decimal = 6 )
2641
-
2642
- user_warning_msg = "y_pred values do not sum to one"
2643
- # check eps and handling of absolute zero and one probabilities
2644
- y_pred = np .asarray (y_pred ) > 0.5
2645
- with pytest .warns (FutureWarning ):
2646
- loss = log_loss (y_true , y_pred , normalize = True , eps = 0.1 )
2647
- with pytest .warns (UserWarning , match = user_warning_msg ):
2648
- assert_almost_equal (loss , log_loss (y_true , np .clip (y_pred , 0.1 , 0.9 )))
2649
-
2650
- # binary case: check correct boundary values for eps = 0
2651
- with pytest .warns (FutureWarning ):
2652
- assert log_loss ([0 , 1 ], [0 , 1 ], eps = 0 ) == 0
2653
- with pytest .warns (FutureWarning ):
2654
- assert log_loss ([0 , 1 ], [0 , 0 ], eps = 0 ) == np .inf
2655
- with pytest .warns (FutureWarning ):
2656
- assert log_loss ([0 , 1 ], [1 , 1 ], eps = 0 ) == np .inf
2657
-
2658
- # multiclass case: check correct boundary values for eps = 0
2659
- with pytest .warns (FutureWarning ):
2660
- assert log_loss ([0 , 1 , 2 ], [[1 , 0 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ]], eps = 0 ) == 0
2661
- with pytest .warns (FutureWarning ):
2662
- assert (
2663
- log_loss ([0 , 1 , 2 ], [[0 , 0.5 , 0.5 ], [0 , 1 , 0 ], [0 , 0 , 1 ]], eps = 0 ) == np .inf
2664
- )
2640
+ assert_allclose (loss , 0.6904911 * 6 )
2665
2641
2666
2642
# raise an error if the number of classes is not equal.
2667
2643
y_true = [1 , 0 , 2 ]
2668
- y_pred = [[0.2 , 0.7 ], [0.6 , 0.5 ], [0.4 , 0.1 ]]
2644
+ y_pred = [[0.3 , 0.7 ], [0.6 , 0.4 ], [0.4 , 0.6 ]]
2669
2645
with pytest .raises (ValueError ):
2670
2646
log_loss (y_true , y_pred )
2671
2647
2672
2648
# case when y_true is a string array object
2673
2649
y_true = ["ham" , "spam" , "spam" , "ham" ]
2674
- y_pred = [[0.2 , 0.7 ], [0.6 , 0.5 ], [0.4 , 0.1 ], [0.7 , 0.2 ]]
2675
- with pytest .warns (UserWarning , match = user_warning_msg ):
2676
- loss = log_loss (y_true , y_pred )
2677
- assert_almost_equal (loss , 1.0383217 , decimal = 6 )
2650
+ y_pred = [[0.3 , 0.7 ], [0.6 , 0.4 ], [0.4 , 0.6 ], [0.7 , 0.3 ]]
2651
+ loss = log_loss (y_true , y_pred )
2652
+ assert_allclose (loss , 0.7469410 )
2678
2653
2679
2654
# test labels option
2680
2655
2681
2656
y_true = [2 , 2 ]
2682
- y_pred = [[0.2 , 0.7 ], [0.6 , 0.5 ]]
2657
+ y_pred = [[0.2 , 0.8 ], [0.6 , 0.4 ]]
2683
2658
y_score = np .array ([[0.1 , 0.9 ], [0.1 , 0.9 ]])
2684
2659
error_str = (
2685
2660
r"y_true contains only one label \(2\). Please provide "
@@ -2688,50 +2663,66 @@ def test_log_loss():
2688
2663
with pytest .raises (ValueError , match = error_str ):
2689
2664
log_loss (y_true , y_pred )
2690
2665
2691
- y_pred = [[0.2 , 0.7 ], [0.6 , 0.5 ], [0.2 , 0.3 ]]
2692
- error_str = "Found input variables with inconsistent numbers of samples: [3, 2]"
2693
- (ValueError , error_str , log_loss , y_true , y_pred )
2666
+ y_pred = [[0.2 , 0.8 ], [0.6 , 0.4 ], [0.7 , 0.3 ]]
2667
+ error_str = r"Found input variables with inconsistent numbers of samples: \[3, 2\]"
2668
+ with pytest .raises (ValueError , match = error_str ):
2669
+ log_loss (y_true , y_pred )
2694
2670
2695
2671
# works when the labels argument is used
2696
2672
2697
2673
true_log_loss = - np .mean (np .log (y_score [:, 1 ]))
2698
2674
calculated_log_loss = log_loss (y_true , y_score , labels = [1 , 2 ])
2699
- assert_almost_equal (calculated_log_loss , true_log_loss )
2675
+ assert_allclose (calculated_log_loss , true_log_loss )
2700
2676
2701
2677
# ensure labels work when len(np.unique(y_true)) != y_pred.shape[1]
2702
2678
y_true = [1 , 2 , 2 ]
2703
- y_score2 = [[0.2 , 0.7 , 0.3 ], [0.6 , 0.5 , 0.3 ], [0.3 , 0.9 , 0.1 ]]
2704
- with pytest . warns ( UserWarning , match = user_warning_msg ):
2705
- loss = log_loss ( y_true , y_score2 , labels = [ 1 , 2 , 3 ] )
2706
- assert_almost_equal ( loss , 1.0630345 , decimal = 6 )
2679
+ y_score2 = [[0.7 , 0.1 , 0.2 ], [0.2 , 0.7 , 0.1 ], [0.1 , 0.7 , 0.2 ]]
2680
+ loss = log_loss ( y_true , y_score2 , labels = [ 1 , 2 , 3 ])
2681
+ assert_allclose ( loss , - np . log ( 0.7 ) )
2682
+
2707
2683
2684
+ @pytest .mark .parametrize ("dtype" , [np .float64 , np .float32 , np .float16 ])
2685
+ def test_log_loss_eps (dtype ):
2686
+ """Check the behaviour of the internal eps, which changes depending on the input dtype.
2708
2687
2709
- def test_log_loss_eps_auto (global_dtype ):
2710
- """Check the behaviour of `eps="auto"` that changes depending on the input
2711
- array dtype.
2712
2688
Non-regression test for:
2713
2689
https://github.com/scikit-learn/scikit-learn/issues/24315
2714
2690
"""
2715
- y_true = np .array ([0 , 1 ], dtype = global_dtype )
2716
- y_pred = y_true . copy ( )
2691
+ y_true = np .array ([0 , 1 ], dtype = dtype )
2692
+ y_pred = np . array ([ 1 , 0 ], dtype = dtype )
2717
2693
2718
- loss = log_loss (y_true , y_pred , eps = "auto" )
2694
+ loss = log_loss (y_true , y_pred )
2719
2695
assert np .isfinite (loss )
2720
2696
2721
2697
2722
- def test_log_loss_eps_auto_float16 ():
2723
- """Check the behaviour of `eps="auto"` for np.float16"""
2724
- y_true = np .array ([0 , 1 ], dtype = np .float16 )
2725
- y_pred = y_true .copy ()
2698
+ @pytest .mark .parametrize ("dtype" , [np .float64 , np .float32 , np .float16 ])
2699
+ def test_log_loss_not_probabilities_warning (dtype ):
2700
+ """Check that log_loss raises a warning when y_pred values don't sum to 1."""
2701
+ y_true = np .array ([0 , 1 , 1 , 0 ])
2702
+ y_pred = np .array ([[0.2 , 0.7 ], [0.6 , 0.3 ], [0.4 , 0.7 ], [0.8 , 0.3 ]], dtype = dtype )
2726
2703
2727
- loss = log_loss (y_true , y_pred , eps = "auto" )
2728
- assert np .isfinite (loss )
2704
+ with pytest .warns (UserWarning , match = "The y_pred values do not sum to one." ):
2705
+ log_loss (y_true , y_pred )
2706
+
2707
+
2708
+ @pytest .mark .parametrize (
2709
+ "y_true, y_pred" ,
2710
+ [
2711
+ ([0 , 1 , 0 ], [0 , 1 , 0 ]),
2712
+ ([0 , 1 , 0 ], [[1 , 0 ], [0 , 1 ], [1 , 0 ]]),
2713
+ ([0 , 1 , 2 ], [[1 , 0 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ]]),
2714
+ ],
2715
+ )
2716
+ def test_log_loss_perfect_predictions (y_true , y_pred ):
2717
+ """Check that log_loss returns 0 for perfect predictions."""
2718
+ # Because of the clipping, the result is not exactly 0
2719
+ assert log_loss (y_true , y_pred ) == pytest .approx (0 )
2729
2720
2730
2721
2731
2722
def test_log_loss_pandas_input ():
2732
2723
# case when input is a pandas series and dataframe gh-5715
2733
2724
y_tr = np .array (["ham" , "spam" , "spam" , "ham" ])
2734
- y_pr = np .array ([[0.2 , 0.7 ], [0.6 , 0.5 ], [0.4 , 0.1 ], [0.7 , 0.2 ]])
2725
+ y_pr = np .array ([[0.3 , 0.7 ], [0.6 , 0.4 ], [0.4 , 0.6 ], [0.7 , 0.3 ]])
2735
2726
types = [(MockDataFrame , MockDataFrame )]
2736
2727
try :
2737
2728
from pandas import DataFrame , Series
@@ -2742,9 +2733,8 @@ def test_log_loss_pandas_input():
2742
2733
for TrueInputType , PredInputType in types :
2743
2734
# y_pred dataframe, y_true series
2744
2735
y_true , y_pred = TrueInputType (y_tr ), PredInputType (y_pr )
2745
- with pytest .warns (UserWarning , match = "y_pred values do not sum to one" ):
2746
- loss = log_loss (y_true , y_pred )
2747
- assert_almost_equal (loss , 1.0383217 , decimal = 6 )
2736
+ loss = log_loss (y_true , y_pred )
2737
+ assert_allclose (loss , 0.7469410 )
2748
2738
2749
2739
2750
2740
def test_brier_score_loss ():
0 commit comments