@@ -2130,38 +2130,66 @@ def test_fit_and_score_working():
     assert result["parameters"] == fit_and_score_kwargs["parameters"]
 
 
+class DataDependentFailingClassifier(BaseEstimator):
+    def __init__(self, max_x_value=None):
+        self.max_x_value = max_x_value
+
+    def fit(self, X, y=None):
+        num_values_too_high = (X > self.max_x_value).sum()
+        if num_values_too_high:
+            raise ValueError(
+                f"Classifier fit failed with {num_values_too_high} values too high"
+            )
+
+    def score(self, X=None, Y=None):
+        return 0.0
+
+
 @pytest.mark.parametrize("error_score", [np.nan, 0])
-def test_cross_validate_failing_fits_warnings(error_score):
+def test_cross_validate_some_failing_fits_warning(error_score):
     # Create a failing classifier to deliberately fail
-    failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER)
+    failing_clf = DataDependentFailingClassifier(max_x_value=8)
     # dummy X data
     X = np.arange(1, 10)
     y = np.ones(9)
-    # fit_and_score_args = [failing_clf, X, None, dict(), None, None, 0, None, None]
     # passing error score to trigger the warning message
     cross_validate_args = [failing_clf, X, y]
-    cross_validate_kwargs = {"cv": 7, "error_score": error_score}
+    cross_validate_kwargs = {"cv": 3, "error_score": error_score}
     # check if the warning message type is as expected
+
+    individual_fit_error_message = (
+        "ValueError: Classifier fit failed with 1 values too high"
+    )
     warning_message = re.compile(
-        "7 fits failed.+total of 7 .+The score on these"
+        "2 fits failed.+total of 3 .+The score on these"
         " train-test partitions for these parameters will be set to"
-        f" {cross_validate_kwargs['error_score']}.",
+        f" {cross_validate_kwargs['error_score']}.+{individual_fit_error_message}",
         flags=re.DOTALL,
     )
 
     with pytest.warns(FitFailedWarning, match=warning_message):
         cross_validate(*cross_validate_args, **cross_validate_kwargs)
 
-    # since we're using FailingClassfier, our error will be the following
-    error_message = "ValueError: Failing classifier failed as required"
 
-    # check traceback is included
-    warning_message = re.compile(
-        "The score on these train-test partitions for these parameters will be set"
-        f" to {cross_validate_kwargs['error_score']}.+{error_message}",
-        re.DOTALL,
+@pytest.mark.parametrize("error_score", [np.nan, 0])
+def test_cross_validate_all_failing_fits_error(error_score):
+    # Create a failing classifier to deliberately fail
+    failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER)
+    # dummy X data
+    X = np.arange(1, 10)
+    y = np.ones(9)
+
+    cross_validate_args = [failing_clf, X, y]
+    cross_validate_kwargs = {"cv": 7, "error_score": error_score}
+
+    individual_fit_error_message = "ValueError: Failing classifier failed as required"
+    error_message = re.compile(
+        "All the 7 fits failed.+your model is misconfigured.+"
+        f"{individual_fit_error_message}",
+        flags=re.DOTALL,
     )
-    with pytest.warns(FitFailedWarning, match=warning_message):
+
+    with pytest.raises(ValueError, match=error_message):
         cross_validate(*cross_validate_args, **cross_validate_kwargs)
 
 