Description
Describe the workflow you want to enable
We can currently pass predict_params and predict_proba_params to Pipelines, predictors, etc., at predict time when making "manual" calls. When performing cross-validation, however, there is no way to pass any params down to the estimator for use when it calls predict/predict_proba. I believe this is a limitation of the library, and I think all cross-validation-related code should allow these params to be passed down to the estimator performing the prediction.
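To make the gap concrete, here is a minimal, library-free sketch. ThresholdClassifier and toy_cross_val_score are hypothetical stand-ins invented for illustration, not scikit-learn code; the only assumption is that the cross-validation helper calls predict with no extra arguments, as described above.

```python
class ThresholdClassifier:
    """Hypothetical toy estimator whose predict accepts an extra keyword."""

    def fit(self, X, y):
        # Nothing is learned; the threshold is supplied at predict time.
        return self

    def predict(self, X, threshold=0.5):
        return [int(x >= threshold) for x in X]


def toy_cross_val_score(estimator, X, y, n_splits=2):
    """Hypothetical stand-in for a cross-validation helper.

    It calls predict without extra arguments, so the caller has no way to
    influence how predictions are made inside the loop.
    """
    fold_size = len(X) // n_splits
    scores = []
    for i in range(n_splits):
        test_idx = range(i * fold_size, (i + 1) * fold_size)
        train_idx = [j for j in range(len(X)) if j not in test_idx]
        estimator.fit([X[j] for j in train_idx], [y[j] for j in train_idx])
        y_pred = estimator.predict([X[j] for j in test_idx])  # no params forwarded
        y_true = [y[j] for j in test_idx]
        scores.append(sum(p == t for p, t in zip(y_pred, y_true)) / len(y_true))
    return scores


X = [0.1, 0.3, 0.6, 0.9]
y = [0, 1, 1, 1]  # labels that match a threshold of 0.2

# A "manual" call can pass the prediction-time parameter...
manual = ThresholdClassifier().fit(X, y).predict(X, threshold=0.2)

# ...but the cross-validation helper always uses the default threshold.
cv_scores = toy_cross_val_score(ThresholdClassifier(), X, y)
```

In this sketch the manual call reproduces the labels exactly, while the cross-validated scores are computed with the default threshold and cannot be changed by the caller.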
Describe your proposed solution
Since scoring is at the core of all cross-validation methods in scikit-learn, I believe it will suffice to update the scorer parent classes to accept a new argument and then work backwards from there. The argument could simply be predict_params, given that predicting probabilities is conceptually a subset of predicting in general, but I'm open to better naming if people have a preference, or even to renaming **predict_proba_params to **predict_params.
Doing so, I can see that we'd need to update (new level of nesting indicates methods/functions that call the parent they're nested under):
- _BaseScorer.__call__ and _BaseScorer._score
  - _score
    - _rfe_single_fit - nothing needed here, as it is used in RFECV, and RFE doesn't have predict_params or predict_proba_params
    - _incremental_fit_estimator
      - learning_curve - generic function to be used on an estimator --> needs updating. Not used anywhere else
    - _fit_and_score
      - evaluate_candidates - created inside BaseSearchCV.fit --> add to the fit signature
      - cross_validate - generic function to be used on an estimator --> needs updating
        - cross_val_score - generic function to be used on an estimator --> needs updating
          - GraphicalLassoCV.fit - no predict_params to be used
      - learning_curve - already updated, see above
      - validation_curve - generic function to be used on an estimator --> needs updating. Not used anywhere else
As for how the arg will be added to method/function signatures, I think it depends on what the method/function does: use **predict_params where only predicting occurs, but use a named predict_params argument where other logic is also performed (e.g. fitting, scoring, etc.). I believe all of the methods/functions above perform both fitting and predicting (either directly or within a sub-call), so I think it would just be predict_params everywhere. This is similar to how fit_params is currently passed around.
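The convention could look roughly like the sketch below. toy_score, toy_fit_and_score, and ThresholdClassifier are hypothetical names used only for illustration; they are not the actual scikit-learn functions and make no claim about how the real signatures would end up.

```python
class ThresholdClassifier:
    """Hypothetical toy estimator; predict-only, so it takes keyword arguments."""

    def fit(self, X, y):
        return self

    def predict(self, X, threshold=0.5):
        return [int(x >= threshold) for x in X]


def toy_score(estimator, X, y, predict_params=None):
    """Hypothetical scorer: accuracy, forwarding predict_params to predict."""
    predict_params = predict_params or {}
    y_pred = estimator.predict(X, **predict_params)
    return sum(p == t for p, t in zip(y_pred, y)) / len(y)


def toy_fit_and_score(estimator, X_train, y_train, X_test, y_test,
                      fit_params=None, predict_params=None):
    """Fits and then scores, so predict_params is a named dict like fit_params."""
    estimator.fit(X_train, y_train, **(fit_params or {}))
    return toy_score(estimator, X_test, y_test, predict_params=predict_params)


# Without predict_params the default threshold (0.5) is used.
default_score = toy_fit_and_score(
    ThresholdClassifier(), [0.1], [0], [0.3, 0.6], [1, 1]
)

# With predict_params the caller controls prediction inside the helper.
tuned_score = toy_fit_and_score(
    ThresholdClassifier(), [0.1], [0], [0.3, 0.6], [1, 1],
    predict_params={"threshold": 0.2},
)
```

Passing a named dict keeps fit_params and predict_params from colliding in functions that do both, which is the reason for preferring it over ** unpacking there.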
Describe alternatives you've considered, if relevant
None
Additional context
Related to this comment
I'm happy to take on this work if the maintainers are happy for it to be done. I already have a local branch with the relevant code updated, but I have yet to update or write tests in case the proposal is rejected.