Support usage of predict_params and predict_proba_params in cross validation #24507
Describe the workflow you want to enable
We can currently pass predict_params and predict_proba_params to Pipelines, predictors, etc. at predict time when making "manual" calls. When performing cross-validation, however, there is no way to pass any params down to the estimator to use when calling predict/predict_proba. I believe this is a limitation of the library, and I think all cross-validation-related code should allow passing said params down to the estimator performing the prediction.
Describe your proposed solution
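To make the core idea concrete, here is a toy sketch in plain Python (no scikit-learn import). ThresholdClassifier, PredictScorer and the threshold keyword are hypothetical stand-ins for an estimator whose predict takes extra params and for a scorer extended with the proposed predict_params argument. They are not the library's actual API.

```python
# Plain-Python toy: ThresholdClassifier, PredictScorer and the `threshold`
# keyword are hypothetical names, not scikit-learn APIs.


class ThresholdClassifier:
    """Toy classifier whose predict() accepts an extra keyword argument."""

    def fit(self, X, y):
        # "Learn" only the mean of the training features as a default cut-off.
        self.mean_ = sum(X) / len(X)
        return self

    def predict(self, X, threshold=None):
        # `threshold` plays the role of a predict-time param.
        cut = self.mean_ if threshold is None else threshold
        return [int(x > cut) for x in X]


class PredictScorer:
    """Accuracy scorer that forwards predict_params to estimator.predict."""

    def __call__(self, estimator, X, y, predict_params=None):
        # Hypothetical new argument: params to pass through at predict time.
        y_pred = estimator.predict(X, **(predict_params or {}))
        return sum(int(a == b) for a, b in zip(y, y_pred)) / len(y)


X = [1, 2, 3, 4]
y = [0, 0, 1, 1]
clf = ThresholdClassifier().fit(X, y)  # default cut-off is 2.5

# "Manual" call: the predict-time param can be passed directly.
print(clf.predict(X, threshold=3.5))  # prints [0, 0, 0, 1]

scorer = PredictScorer()
# Without predict_params the scorer calls predict(X) with no extra params.
print(scorer(clf, X, y))  # prints 1.0
# With the proposed argument, the scorer forwards the params to predict().
print(scorer(clf, X, y, predict_params={"threshold": 3.5}))  # prints 0.75
```

The point of the sketch is that the scorer is the single place where predict is called, so once it accepts the params, everything above it only needs to pass them along.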
Given that scoring is at the core of all cross-validation methods within scikit-learn, I believe it will suffice to update the scorer parent classes to accept a new arg (maybe just predict_params, given that predicting probabilities is conceptually a subset of predicting in general, but I'd be open to better naming here if people have a preference, or even to renaming **predict_proba_params to **predict_params) and then work backwards from there.
Doing so, I can see that we'd need to update the following (a new level of nesting indicates methods/functions that call the parent they're nested under):
- _BaseScorer.__call__ and _BaseScorer._score
  - _score
    - _rfe_single_fit - nothing needed from here, as it's used in RFECV, and RFE doesn't have predict_params or predict_proba_params
    - _incremental_fit_estimator
      - learning_curve - generic function to be used on an estimator --> needs updating. Not used anywhere
    - _fit_and_score
      - evaluate_candidates - created inside of BaseSearchCV.fit --> add to fit signature
      - cross_validate - generic function to be used on an estimator --> needs updating
        - cross_val_score - generic function to be used on an estimator --> needs updating
          - GraphicalLassoCV.fit - no predict_params to be used
      - learning_curve - already updated from above
      - validation_curve - generic function to be used on an estimator --> needs updating. Not used anywhere
As for how the arg will be added to method/function signatures, I think in general it depends on what the method/function is doing: e.g. use **predict_params in situations where only predicting is occurring, but use predict_params in situations where other logic is also performed (e.g. fitting, scoring, etc.). I believe all the above methods/functions perform both fitting and predicting (either directly or within a sub-call), so I think it'd just be predict_params everywhere. This seems similar to how fit_params is currently passed around.
Describe alternatives you've considered, if relevant
None
Additional context
Related to this comment
I'm happy to take on this work if maintainers are happy for it to be done (I've actually got a branch locally with the relevant code updated, but am yet to update/write new tests in case the proposal is rejected)
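For reference, a rough sketch of how predict_params might be threaded down the call chain listed above, following the proposed signature convention: functions that both fit and predict keep fit_params and predict_params as separate dicts, and the dict is only unpacked as **predict_params at the final predict call. All names here (MeanThresholdClassifier, accuracy_scorer, score_sketch, fit_and_score_sketch, cross_validate_sketch) are hypothetical, simplified stand-ins for the private helpers, not scikit-learn code.

```python
# Hypothetical, simplified stand-ins for the helpers listed in this issue;
# none of these names are scikit-learn APIs.


class MeanThresholdClassifier:
    """Toy classifier: predicts 1 when a feature exceeds a cut-off."""

    def fit(self, X, y):
        # Default cut-off is the mean of the training features.
        self.mean_ = sum(X) / len(X)
        return self

    def predict(self, X, threshold=None):
        # `threshold` plays the role of a predict-time param.
        cut = self.mean_ if threshold is None else threshold
        return [int(x > cut) for x in X]


def accuracy_scorer(estimator, X, y, predict_params=None):
    # Scorer level: only predicts, so the dict is unpacked here.
    y_pred = estimator.predict(X, **(predict_params or {}))
    return sum(int(a == b) for a, b in zip(y, y_pred)) / len(y)


def score_sketch(estimator, X_test, y_test, scorer, predict_params=None):
    # Mirrors the role of _score: pass predict_params through untouched.
    return scorer(estimator, X_test, y_test, predict_params=predict_params)


def fit_and_score_sketch(estimator, X, y, scorer, train, test,
                         fit_params=None, predict_params=None):
    # Fits AND predicts, so fit and predict params stay in separate dicts
    # instead of being merged into one **kwargs (their keys could collide).
    X_train, y_train = [X[i] for i in train], [y[i] for i in train]
    X_test, y_test = [X[i] for i in test], [y[i] for i in test]
    estimator.fit(X_train, y_train, **(fit_params or {}))
    return score_sketch(estimator, X_test, y_test, scorer, predict_params)


def cross_validate_sketch(estimator, X, y, scorer, cv,
                          fit_params=None, predict_params=None):
    # One fit_and_score_sketch call per (train_indices, test_indices) pair.
    return [
        fit_and_score_sketch(estimator, X, y, scorer, train, test,
                             fit_params, predict_params)
        for train, test in cv
    ]


X = [1, 2, 3, 4, 5, 6]
y = [0, 0, 0, 1, 1, 1]
cv = [([0, 1, 2], [3, 4, 5]), ([3, 4, 5], [0, 1, 2])]
clf = MeanThresholdClassifier()

print(cross_validate_sketch(clf, X, y, accuracy_scorer, cv))
# prints [1.0, 1.0]
print(cross_validate_sketch(clf, X, y, accuracy_scorer, cv,
                            predict_params={"threshold": 4.5}))
# prints [0.6666666666666666, 1.0]
```

Keeping predict_params as an explicit dict in the intermediate functions mirrors how fit_params is handled, and avoids ambiguity about which method a given keyword is meant for.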