Support usage of predict_params and predict_proba_params in cross validation #24507

Closed
hsorsky opened this issue Sep 24, 2022 · 2 comments · Fixed by #26789

@hsorsky
Contributor

hsorsky commented Sep 24, 2022

Describe the workflow you want to enable

We can currently pass predict_params and predict_proba_params to Pipelines, predictors, etc. at predict time when making "manual" calls. When performing cross-validation, however, there is no way to pass any params down to the estimator for use when it calls predict/predict_proba. I believe this is a limitation of the library, and I think all cross-validation-related code should allow passing these params down to the estimator performing the prediction.

Describe your proposed solution

Given that scoring is at the core of all cross-validation methods in scikit-learn, I believe it will suffice to update the scorer parent classes to accept a new argument and then work backwards from there. The argument could be just predict_params, given that predicting probabilities is conceptually a subset of predicting in general, but I'm open to better naming if people have a preference, or even to renaming **predict_proba_params to **predict_params.

Doing so, I can see that we'd need to update (new level of nesting indicates methods/functions that call the parent they're nested under):

  • _BaseScorer.__call__ and _BaseScorer._score
    • _score
      • _rfe_single_fit - nothing needed here, as it is used in RFECV, and RFE doesn't have predict_params or predict_proba_params
      • _incremental_fit_estimator
        • learning_curve - generic function to be used on estimator --> needs updating. not used anywhere
      • _fit_and_score
        • evaluate_candidates - created inside of BaseSearchCV.fit --> add to fit signature
        • cross_validate - generic function to be used on estimator --> needs updating.
          • cross_val_score - generic function to be used on estimator --> needs updating.
            • GraphicalLassoCV.fit - no predict_params to be used
        • learning_curve - already updated from above
        • validation_curve - generic function to be used on estimator --> needs updating. not used anywhere

As for how the arg will be added to method/function signatures, I think in general it depends on what the method/function is doing: use **predict_params where only predicting occurs, but use a predict_params argument where other logic is also performed (e.g. fitting, scoring, etc.). I believe all of the above methods/functions perform both fitting and predicting (either directly or within a sub-call), so I think it'd just be predict_params everywhere. This is similar to how fit_params is currently passed around.
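To make the convention concrete, here is a minimal pure-Python sketch of how a predict_params dict could be threaded from a cross-validation helper through a scorer down to predict. It is not scikit-learn code: ToyMeanRegressor, neg_mae_scorer, toy_cross_validate and the offset keyword are all hypothetical names invented for illustration, and the real change would touch the functions listed above.

```python
from statistics import mean


class ToyMeanRegressor:
    """Hypothetical estimator whose predict() accepts an extra keyword."""

    def fit(self, X, y, **fit_params):
        self.mean_ = mean(y)
        return self

    def predict(self, X, offset=0.0):
        # `offset` stands in for any predict-time parameter.
        return [self.mean_ + offset for _ in X]


def neg_mae_scorer(estimator, X, y, predict_params=None):
    """Hypothetical scorer: unpacks predict_params into estimator.predict."""
    predict_params = predict_params or {}
    y_pred = estimator.predict(X, **predict_params)
    return -mean(abs(t - p) for t, p in zip(y, y_pred))


def toy_cross_validate(estimator_cls, X, y, n_splits=2,
                       fit_params=None, predict_params=None):
    """Fits and scores per fold, so both param sets are plain dict arguments
    (mirroring how fit_params is passed around) rather than **kwargs."""
    fit_params = fit_params or {}
    scores = []
    for k in range(n_splits):
        # Simple interleaved split: sample i goes to test fold i % n_splits.
        test = [i for i in range(len(X)) if i % n_splits == k]
        train = [i for i in range(len(X)) if i % n_splits != k]
        est = estimator_cls().fit(
            [X[i] for i in train], [y[i] for i in train], **fit_params
        )
        scores.append(
            neg_mae_scorer(
                est,
                [X[i] for i in test],
                [y[i] for i in test],
                predict_params=predict_params,
            )
        )
    return scores


X = [[0], [1], [2], [3]]
y = [1.0, 2.0, 3.0, 4.0]
default_scores = toy_cross_validate(ToyMeanRegressor, X, y)
shifted_scores = toy_cross_validate(
    ToyMeanRegressor, X, y, predict_params={"offset": 1.0}
)
```

The helper takes predict_params as a named dict because it also fits and scores, while only the innermost call unpacks it with ** into predict.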

Describe alternatives you've considered, if relevant

None

Additional context

Related to this comment

I'm happy to take on this work if maintainers are happy for it to be done (I've actually got a branch locally with the relevant code updated, but am yet to update/write new tests in case the proposal is rejected)

@adrinjalali
Member

#24027 will fix this.

@adrinjalali adrinjalali removed the Needs Triage Issue requires triage label Sep 26, 2022
@hsorsky
Contributor Author

hsorsky commented Sep 26, 2022

Amazing, thank you!
