

Support usage of predict_params and predict_proba_params in cross validation #24507

Closed
@hsorsky

Description


Describe the workflow you want to enable

We can currently pass predict_params and predict_proba_params to Pipelines, predictors, etc. at predict time when calling them "manually". When performing cross-validation, however, there is no way to pass any params down to the estimator for use when calling predict/predict_proba. I believe this is a limitation of the library, and that all cross-validation-related code should allow these params to be passed down to the estimator performing the prediction.
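
To make the gap concrete, here is a minimal pure-Python sketch (no scikit-learn needed). ThresholdClassifier and its threshold argument are hypothetical stand-ins for an estimator whose predict accepts extra keyword params, and naive_cross_val_accuracy is a simplified stand-in for the cross-validation helpers, which only ever call predict(X):

```python
class ThresholdClassifier:
    """Hypothetical estimator whose predict() takes a predict-time parameter."""

    def fit(self, X, y):
        return self  # nothing to learn in this toy example

    def predict(self, X, threshold=0.5):
        # `threshold` plays the role of a predict_params entry.
        return [1 if x >= threshold else 0 for x in X]


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def naive_cross_val_accuracy(estimator, X, y, n_splits=2):
    """Simplified stand-in for cross-validation (contiguous, unshuffled folds)."""
    fold = len(X) // n_splits
    scores = []
    for i in range(n_splits):
        lo, hi = i * fold, (i + 1) * fold
        estimator.fit(X[:lo] + X[hi:], y[:lo] + y[hi:])
        # The caller has no way to supply `threshold` here, so the default is used.
        scores.append(accuracy(y[lo:hi], estimator.predict(X[lo:hi])))
    return scores


X = [0.1, 0.3, 0.4, 0.6, 0.2, 0.35, 0.7, 0.9]
y = [0, 1, 1, 1, 0, 1, 1, 1]
clf = ThresholdClassifier().fit(X, y)
manual = accuracy(y, clf.predict(X, threshold=0.25))       # manual call: param honoured
cv_scores = naive_cross_val_accuracy(ThresholdClassifier(), X, y)  # always threshold=0.5
```

Here the manual call scores 1.0, while the cross-validated scores ([0.5, 0.75]) silently reflect the default threshold instead of the one the user actually wants.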

Describe your proposed solution

Given that scoring is at the core of all cross-validation methods within scikit-learn, I believe it will suffice to update the scorer parent classes to accept a new arg and then work backwards from there. The arg could simply be called predict_params, since predicting probabilities is conceptually a subset of predicting in general, though I'd be open to better naming if people have a preference (or even to renaming **predict_proba_params to **predict_params).
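
As a rough illustration of the scorer change, here is a sketch assuming a single predict_params dict is accepted by the scorer's __call__ and forwarded to whichever prediction method the scorer uses. PredictScorer, ProbaScorer and ToyEstimator are hypothetical classes, only loosely modelled on scikit-learn's private scorer hierarchy, not its actual API:

```python
class PredictScorer:
    """Hypothetical scorer that forwards predict-time params to estimator.predict."""

    def __init__(self, score_func):
        self._score_func = score_func

    def _predict(self, estimator, X, predict_params):
        return estimator.predict(X, **predict_params)

    def __call__(self, estimator, X, y_true, predict_params=None):
        # One `predict_params` dict, as proposed; None means "no extra params".
        params = {} if predict_params is None else predict_params
        return self._score_func(y_true, self._predict(estimator, X, params))


class ProbaScorer(PredictScorer):
    """Same arg name, but forwarded to estimator.predict_proba instead."""

    def _predict(self, estimator, X, predict_params):
        return estimator.predict_proba(X, **predict_params)


class ToyEstimator:
    """Hypothetical estimator whose predict methods take extra keyword params."""

    def predict(self, X, threshold=0.5):
        return [1 if x >= threshold else 0 for x in X]

    def predict_proba(self, X, clip=1.0):
        # Columns: [P(class 0), P(class 1)], with P(class 1) capped at `clip`.
        return [[1 - min(x, clip), min(x, clip)] for x in X]


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def mean_true_class_proba(y_true, proba):
    return sum(row[t] for t, row in zip(y_true, proba)) / len(y_true)


X, y = [0.2, 0.6, 0.8], [0, 1, 1]
est = ToyEstimator()
acc_default = PredictScorer(accuracy)(est, X, y)  # threshold=0.5
acc_strict = PredictScorer(accuracy)(est, X, y, predict_params={"threshold": 0.7})
proba_clipped = ProbaScorer(mean_true_class_proba)(est, X, y, predict_params={"clip": 0.5})
```

Using one name for both scorers means callers further up the chain don't need to know whether predict or predict_proba will end up consuming the params.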

To do so, I can see that we'd need to update the following (each new level of nesting lists the methods/functions that call the parent they're nested under):

  • _BaseScorer.__call__ and _BaseScorer._score
    • _score
      • _rfe_single_fit - nothing needed here, as it is only used in RFECV, and RFE doesn't have predict_params or predict_proba_params
      • _incremental_fit_estimator
        • learning_curve - generic function to be used on an estimator --> needs updating. Not called anywhere else internally
      • _fit_and_score
        • evaluate_candidates - created inside BaseSearchCV.fit --> add predict_params to the fit signature
        • cross_validate - generic function to be used on an estimator --> needs updating.
          • cross_val_score - generic function to be used on an estimator --> needs updating.
            • GraphicalLassoCV.fit - no predict_params to be used
        • learning_curve - already updated above
        • validation_curve - generic function to be used on an estimator --> needs updating. Not called anywhere else internally
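
The plumbing for the chain above can be sketched as follows, assuming predict_params is threaded through exactly like fit_params. score, fit_and_score, cross_validate and cross_val_score here are simplified, hypothetical stand-ins named after the functions in the tree (the real ones have very different signatures), and accuracy_scorer assumes a scorer that accepts the proposed predict_params arg:

```python
class ThresholdClassifier:
    """Hypothetical estimator with a predict-time `threshold` param."""

    def fit(self, X, y):
        return self  # nothing to learn in this toy example

    def predict(self, X, threshold=0.5):
        return [1 if x >= threshold else 0 for x in X]


def accuracy_scorer(estimator, X, y_true, predict_params=None):
    # Hypothetical scorer with the proposed extra arg.
    y_pred = estimator.predict(X, **(predict_params or {}))
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def score(estimator, X_test, y_test, scorer, predict_params=None):
    # Stand-in for `_score`: the new arg is simply handed to the scorer.
    return scorer(estimator, X_test, y_test, predict_params=predict_params)


def fit_and_score(estimator, X_train, y_train, X_test, y_test, scorer,
                  fit_params=None, predict_params=None):
    # Stand-in for `_fit_and_score`: fit_params go to fit, predict_params to scoring.
    estimator.fit(X_train, y_train, **(fit_params or {}))
    return score(estimator, X_test, y_test, scorer, predict_params)


def cross_validate(estimator, X, y, scorer, n_splits=2,
                   fit_params=None, predict_params=None):
    # Contiguous, unshuffled folds for simplicity.
    fold = len(X) // n_splits
    scores = []
    for i in range(n_splits):
        lo, hi = i * fold, (i + 1) * fold
        scores.append(fit_and_score(
            estimator, X[:lo] + X[hi:], y[:lo] + y[hi:], X[lo:hi], y[lo:hi],
            scorer, fit_params=fit_params, predict_params=predict_params))
    return {"test_score": scores}


def cross_val_score(estimator, X, y, scorer, n_splits=2,
                    fit_params=None, predict_params=None):
    # Thin wrapper: it only forwards the new arg on to cross_validate.
    result = cross_validate(estimator, X, y, scorer, n_splits=n_splits,
                            fit_params=fit_params, predict_params=predict_params)
    return result["test_score"]


X = [0.1, 0.3, 0.4, 0.6, 0.2, 0.35, 0.7, 0.9]
y = [0, 1, 1, 1, 0, 1, 1, 1]
default_scores = cross_val_score(ThresholdClassifier(), X, y, accuracy_scorer)
tuned_scores = cross_val_score(ThresholdClassifier(), X, y, accuracy_scorer,
                               predict_params={"threshold": 0.25})
```

Each layer only forwards the dict, so the change per function is small. With the params reaching predict, the tuned scores become [1.0, 1.0] instead of the default [0.5, 0.75].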

As for how the arg will be added to method/function signatures, I think it generally depends on what the method/function is doing: use **predict_params in situations where only predicting occurs, but use a predict_params dict in situations where other logic is also performed (e.g. fitting, scoring, etc.). I believe all of the above methods/functions perform both fitting and predicting (either directly or within a sub-call), so I think it would just be predict_params everywhere. This is similar to how fit_params is currently passed around.
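
A small sketch of why the distinction matters: when a function both fits and predicts, separate dicts avoid ambiguity about which call an extra kwarg belongs to. RecordingEstimator, predict_only and fit_then_predict are hypothetical names for illustration, and the verbose key is an arbitrary example:

```python
class RecordingEstimator:
    """Hypothetical estimator that records the keyword params each method receives."""

    def fit(self, X, y, **fit_kwargs):
        self.fit_kwargs = fit_kwargs
        return self

    def predict(self, X, **predict_kwargs):
        self.predict_kwargs = predict_kwargs
        return [0 for _ in X]


def predict_only(estimator, X, **predict_params):
    # Only one destination for extra kwargs, so **predict_params is unambiguous.
    return estimator.predict(X, **predict_params)


def fit_then_predict(estimator, X_train, y_train, X_test,
                     fit_params=None, predict_params=None):
    # Two destinations: separate dicts let the same key (here `verbose`) carry
    # different values for fit and predict without colliding.
    estimator.fit(X_train, y_train, **(fit_params or {}))
    return estimator.predict(X_test, **(predict_params or {}))


est = RecordingEstimator()
fit_then_predict(est, [1, 2], [0, 1], [3],
                 fit_params={"verbose": True}, predict_params={"verbose": False})
after_fit_then_predict = (est.fit_kwargs, est.predict_kwargs)
predict_only(est, [4], verbose=True)
```

With a single **kwargs on fit_then_predict, there would be no way to express "verbose=True for fit but verbose=False for predict".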

Describe alternatives you've considered, if relevant

None

Additional context

Related to this comment

I'm happy to take on this work if maintainers are happy for it to be done (I've actually got a branch locally with the relevant code updated, but am yet to update/write new tests in case the proposal is rejected)

Metadata


Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
