@@ -475,3 +475,47 @@ def test_cutoffclassifier_fit_params(objective_metric, fit_params_type):
475
475
classifier , objective_metric = objective_metric , objective_value = 0.5
476
476
)
477
477
model .fit (X , y , ** fit_params )
478
+
479
+
480
+ @pytest .mark .parametrize (
481
+ "objective_metric, objective_value" , [("tpr" , 0.5 ), ("tnr" , 0.5 )]
482
+ )
483
+ @pytest .mark .parametrize (
484
+ "response_method" , ["auto" , "decision_function" , "predict_proba" ]
485
+ )
486
+ def test_cutoffclassifier_response_method_scorer_tnr_tpr (
487
+ objective_metric , objective_value , response_method , global_random_seed
488
+ ):
489
+ """Check that we use the proper scorer and forwarding the requested response method
490
+ for `tnr` and `tpr`.
491
+ """
492
+ X , y = make_classification (n_samples = 100 , random_state = global_random_seed )
493
+ classifier = LogisticRegression ()
494
+
495
+ model = CutOffClassifier (
496
+ classifier ,
497
+ objective_metric = objective_metric ,
498
+ objective_value = objective_value ,
499
+ response_method = response_method ,
500
+ )
501
+ model .fit (X , y )
502
+
503
+ # Note that optimizing TPR will increase the decision threshold while optimizing
504
+ # TNR will decrease it. We therefore use the centered threshold (i.e. 0.5 for
505
+ # probabilities and 0.0 for decision function) to check that the decision threshold
506
+ # is properly set.
507
+ if response_method in ("auto" , "predict_proba" ):
508
+ # "auto" will fall back in priority on `predict_proba` if `estimator`
509
+ # supports it.
510
+ # we expect the decision threshold to be in [0, 1]
511
+ if objective_metric == "tpr" :
512
+ assert 0.5 < model .decision_threshold_ < 1
513
+ else : # "tnr"
514
+ assert 0 < model .decision_threshold_ < 0.5
515
+ else : # "decision_function"
516
+ # we expect the decision function to be centered in 0.0 and to be larger than
517
+ # -1 and 1.
518
+ if objective_metric == "tpr" :
519
+ assert 0 < model .decision_threshold_ < 20
520
+ else : # "tnr"
521
+ assert - 20 < model .decision_threshold_ < 0
0 commit comments