sample-props alternate implementation #20350
Closed: adrinjalali wants to merge 159 commits into scikit-learn:main from adrinjalali:sample-props-alternate
159 commits
b0b6fd1 first try...almost (adrinjalali)
b33940d working pipeline (adrinjalali)
a397748 grid search (adrinjalali)
6a3b725 adding some tests (adrinjalali)
0b38a88 Merge remote-tracking branch 'upstream/master' into sample-props (adrinjalali)
e2ca8da pep8 (adrinjalali)
3472b34 Merge remote-tracking branch 'upstream/master' into sample-props (adrinjalali)
5113015 moving function out of base class to validaiton, adding docs (adrinjalali)
c70474f Merge remote-tracking branch 'upstream/master' into sample-props (adrinjalali)
8810724 Merge remote-tracking branch 'upstream/master' into sample-props (adrinjalali)
4486ec0 refactor and simplify code (adrinjalali)
9b47761 merged master, half way trough scoring (adrinjalali)
8a116f5 Merge remote-tracking branch 'upstream/master' into sample-props (adrinjalali)
987b289 first scoring param in GS works (adrinjalali)
09ae61f simplify set_props_request (adrinjalali)
2ef3bad rename to medata_request (adrinjalali)
7f8ae6f fix docstring and None inputs (adrinjalali)
5bd1c83 fix scorers' issues (adrinjalali)
d8b55af minor cleanup (adrinjalali)
5b8fd65 tests are okay with parameters not being passed but requested (adrinjalali)
9c1b772 accept old style fit params (adrinjalali)
8eca900 make test_pipeline pass, ignore future warnings (adrinjalali)
8948d45 include sample_weights in **kwargs in metrics (adrinjalali)
1851419 pep8 (adrinjalali)
421a673 separate _MetadataConsumer (adrinjalali)
1a1c2b9 don't pass sample_weight=None in metrics (adrinjalali)
50317ee Merge remote-tracking branch 'upstream/master' into sample-props (adrinjalali)
4967802 cleanup and ignore private attrs set in __init__ in common tests (adrinjalali)
32d8020 rfe passes score params in fit, and cleanup (adrinjalali)
845204b fixes to pass model_selection tests (adrinjalali)
213db45 adding tests (adrinjalali)
0591102 Merge remote-tracking branch 'upstream/master' into sample-props (adrinjalali)
fd1c365 fix some test issues (adrinjalali)
08bfc9f fix some test issues (adrinjalali)
72029df test_props passes (adrinjalali)
28132b0 change to specific method to consume specific metadata (Joel's proposal) (adrinjalali)
c2d3683 use build_router_metadata_request (adrinjalali)
63fe5e0 Merge remote-tracking branch 'upstream/master' into sample-props (adrinjalali)
f701477 Joel's minor suggestions (adrinjalali)
f973d05 adding SLEP usecases as tests (adrinjalali)
8d75b8d handle params in cross_validate (adrinjalali)
1461a4e slep usecases pass (adrinjalali)
29e6b07 Merge remote-tracking branch 'upstream/master' into sample-props (adrinjalali)
6d27c79 generalize default values, create GroupConsuer (adrinjalali)
450504d Merge remote-tracking branch 'upstream/master' into sample-props (adrinjalali)
121d696 Merge remote-tracking branch 'upstream/master' into sample-props (adrinjalali)
554f539 an initial documentation (adrinjalali)
2056e6c fix mocking (adrinjalali)
7966b12 fixing tests, broke tests, bisecting (adrinjalali)
9da3c2d slep tests pass again (adrinjalali)
115fe9c fix more issues with metadata routing (adrinjalali)
eae11f9 Merge remote-tracking branch 'upstream/master' into sample-props (adrinjalali)
4181b92 apply Joel's suggestions (adrinjalali)
da7d299 get_metadata_request doesn't change any state (adrinjalali)
77ce35c make _passthrough_scorer implement get_metadata_request (adrinjalali)
3515951 fix test_search (adrinjalali)
a124017 _passthrough_scorer is a class, fix test (adrinjalali)
59fa589 linear_models: need to request sample weight in tests (adrinjalali)
6c86d55 fix ensemble tests (adrinjalali)
d7d8627 make scorers request sample_weight, and fix inspection tests (adrinjalali)
cb00075 fix request_sample_weight example in base.py (adrinjalali)
ae363a4 including metadata routing in user guides (adrinjalali)
baa3882 use props instead of fit_params in model_selection tests (adrinjalali)
849e21e fix TransformedTargetRegressor and tests (adrinjalali)
4b6e331 add metatadata_routing to user_guide toc (adrinjalali)
8000b85 add missing import in metadata_routing.rst (adrinjalali)
50bf66b add more missing imports (adrinjalali)
b5d3e13 fix metadata_routing.rst (adrinjalali)
b117b66 Merge remote-tracking branch 'upstream/master' into sample-props (adrinjalali)
941bc60 add some MetadataRequest tests (adrinjalali)
57fcbc1 add SampleWeightConsumer to StandardScaler (agramfort)
e4e7313 add more tests (adrinjalali)
64345e5 apply changes from comments (adrinjalali)
951e75c validate defaults, fix tests (adrinjalali)
0d04e3a merge upstream/main (adrinjalali)
3098ca0 Merge branch 'sample-props' of github.com:adrinjalali/scikit-learn in… (adrinjalali)
f2266f9 trying new implementation (adrinjalali)
477f92f MetadataRequest seems working (adrinjalali)
33e9e54 add {method}_requests methods (adrinjalali)
6f5df5f pipeline uses new MetadataRouter (adrinjalali)
a5abfb2 basic GS pass (adrinjalali)
aa52ec5 Merge commit '0e7761cdc4f244adb4803f1a97f0a9fe4b365a99' into sample-p… (adrinjalali)
bce14b4 MAINT Adds target_version to black config (#20293) (thomasjpfan)
f5aa0c8 apply black (adrinjalali)
6c4670a Merge remote-tracking branch 'upstream/main' into sample-props-alternate (adrinjalali)
161abc4 fix more issues and pass test_invalid_arg_given (adrinjalali)
c3e6450 remove build_router* tests (adrinjalali)
c0acf3c case A passes (adrinjalali)
7a7399a case B passes (adrinjalali)
4613667 rename request_sample_weight (adrinjalali)
ce4cf3c case D pass (adrinjalali)
bf68dd0 test_props.py passes (adrinjalali)
cc55ece make many more tests pass (adrinjalali)
257e351 metrics tests pass (adrinjalali)
b5a36cc linear_model tests pass (adrinjalali)
aefada2 fix docstrings (adrinjalali)
e97beb2 fix score calls (adrinjalali)
2b6071b ensemble tests pass (adrinjalali)
4ec4969 scorer shouldn't validate its inputs (adrinjalali)
369a418 model_selection pass (adrinjalali)
494a15d fix docstrings (from Joel's review) (adrinjalali)
b803772 remove extras from _passthrough_scorer (adrinjalali)
6fa083f Merge remote-tracking branch 'upstream/main' into sample-props-alternate (adrinjalali)
8c18233 adding sample weight support to coordinate descent models (adrinjalali)
12d2a7e Merge remote-tracking branch 'upstream/main' into sample-props-alternate (adrinjalali)
a685bb3 more fixes on linear models tests (adrinjalali)
99f1fe4 mostly doc fixes (adrinjalali)
933d83c **kwargs -> kwargs and a fix in _logistic (adrinjalali)
295e847 Merge remote-tracking branch 'upstream/main' into sample-props-alternate (adrinjalali)
ee57e4a remove output=dict (adrinjalali)
455995f minor doc fix (adrinjalali)
a26037c fix doc code (adrinjalali)
ffe9c96 fix futurewarning in coordinate discent tests (adrinjalali)
5f627a0 fix numpydoc check on generated docstrings (adrinjalali)
79d4198 fix numpydoc check on generated docstrings (adrinjalali)
51c9dff trying to fix docstring indent (adrinjalali)
75a40ff trying again (adrinjalali)
89cf0a5 Merge remote-tracking branch 'upstream/main' into sample-props-alternate (adrinjalali)
9fc2d93 minor rst fix (adrinjalali)
c020c46 Merge remote-tracking branch 'upstream/main' into sample-props-alternate (adrinjalali)
3f7c3d7 black test file (adrinjalali)
1ab3037 add signature based default values (adrinjalali)
f8c9070 add removal of implicit args (adrinjalali)
f9490be ignore first parameter in the signature (adrinjalali)
896b24a improve user guide (adrinjalali)
f91b23f Merge remote-tracking branch 'upstream/main' into sample-props-alternate (adrinjalali)
c4e9218 EXISTS_NOT -> UNUSED (adrinjalali)
898f356 remove extra print statement from test (adrinjalali)
ab42604 remove SampleWeightConsumer (adrinjalali)
d106858 add self_metadata and make LogisticRegressionCV validate the input (adrinjalali)
ed213df test edge case in scorers (adrinjalali)
ae12ad8 test deprecated fit_params in cross_validate (adrinjalali)
cb8d14f test pipeline with old style and new style of params mixed (adrinjalali)
5ae0801 improve metadata_requests.py coverage (adrinjalali)
7921d6f Merge remote-tracking branch 'upstream/main' into sample-props-alternate (adrinjalali)
efdf7df add missing docstring param, catch pipeline test FutureWarning (adrinjalali)
e68c4fc add missing : to versionadded (adrinjalali)
9f68c73 rfe numpydoc fix (adrinjalali)
931111c cleanup, reduce diff, versionadded (adrinjalali)
34b06c9 Merge remote-tracking branch 'upstream/main' into sample-props-alternate (adrinjalali)
4ee9e08 ENH RFE passes metadata to the underlying fit and score methods (adrinjalali)
67de752 add changelog (adrinjalali)
fb00696 fix extra whitespace (adrinjalali)
3e155f0 Merge branch 'rfe-fix' into sample-props-alternate (adrinjalali)
c621f66 add tests for test_check_no_attributes_set_in_init (adrinjalali)
4056721 fix GB's routing and metadata request (adrinjalali)
075e810 Merge remote-tracking branch 'upstream/main' into sample-props-alternate (adrinjalali)
e65848e fix n_classes check in gp (adrinjalali)
4cd44eb better overwrite semantics (adrinjalali)
d985145 Merge remote-tracking branch 'upstream/main' into sample-props-alternate (adrinjalali)
c707d77 Merge remote-tracking branch 'upstream/main' into sample-props-alternate (adrinjalali)
ee90987 add 'smart' overwrite option (adrinjalali)
855064f fix RFE param validation in fit and score (adrinjalali)
af1dea5 Merge remote-tracking branch 'upstream/main' into sample-props-alternate (adrinjalali)
d1ca507 DOC remove link to doc from getting started (adrinjalali)
b5fc380 improve docstring in TargetRegressor (adrinjalali)
1034bce fix overwrite in _gb.py (adrinjalali)
1300867 model -> est in _ridge.py (adrinjalali)
7341975 remove score_params from make_scorer (adrinjalali)

@@ -0,0 +1,187 @@

.. _metadata_routing:

Metadata Routing
================

This guide demonstrates how metadata such as ``sample_weight`` can be routed
and passed along to estimators, scorers, and CV splitters through
meta-estimators such as ``Pipeline`` and ``GridSearchCV``. In order to pass
metadata to a method such as ``fit`` or ``score``, the object accepting the
metadata must *request* it. For estimators and splitters this is done via
``*_requests`` methods, e.g. ``fit_requests(...)``, and for scorers this is
done via the ``score_requests`` method of a scorer. For grouped splitters such
as ``GroupKFold``, a ``groups`` parameter is requested by default. This is best
demonstrated by the following examples.

Usage Examples
**************
Here we present a few examples to show different common use cases. The
examples in this section require the following imports and data::

    >>> import numpy as np
    >>> from sklearn.metrics import make_scorer, accuracy_score
    >>> from sklearn.linear_model import LogisticRegressionCV
    >>> from sklearn.linear_model import LogisticRegression
    >>> from sklearn.model_selection import cross_validate
    >>> from sklearn.model_selection import GridSearchCV
    >>> from sklearn.model_selection import GroupKFold
    >>> from sklearn.feature_selection import SelectKBest
    >>> from sklearn.pipeline import make_pipeline
    >>> n_samples, n_features = 100, 4
    >>> X = np.random.rand(n_samples, n_features)
    >>> y = np.random.randint(0, 2, size=n_samples)
    >>> my_groups = np.random.randint(0, 10, size=n_samples)
    >>> my_weights = np.random.rand(n_samples)
    >>> my_other_weights = np.random.rand(n_samples)

Weighted scoring and fitting
----------------------------

Here ``GroupKFold`` requests ``groups`` by default. However, we need to
explicitly request weights in ``make_scorer`` and for ``LogisticRegressionCV``.
Both of these *consumers* understand the meaning of the key
``"sample_weight"``::

    >>> weighted_acc = make_scorer(accuracy_score).score_requests(
    ...     sample_weight=True
    ... )
    >>> lr = LogisticRegressionCV(
    ...     cv=GroupKFold(), scoring=weighted_acc,
    ... ).fit_requests(sample_weight=True)
    >>> cv_results = cross_validate(
    ...     lr,
    ...     X,
    ...     y,
    ...     cv=GroupKFold(),
    ...     props={"sample_weight": my_weights, "groups": my_groups},
    ...     scoring=weighted_acc,
    ... )

Error handling: if ``props={"sample_weigh": my_weights, ...}`` were passed
(note the typo), ``cross_validate`` would raise an error, since
``"sample_weigh"`` was not requested by any of its children.
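
As a rough illustration of this check, here is a stand-alone sketch in plain
Python. It does not use scikit-learn and is not the implementation from this
pull request; ``validate_props`` and the ``children`` mapping are hypothetical
names, and the error text is invented for the example:

```python
def validate_props(props, requested_by_children):
    """Raise if a key in ``props`` is not requested by any child.

    ``requested_by_children`` maps a (hypothetical) child name to the set of
    metadata keys that this child requests.
    """
    requested = set().union(*requested_by_children.values())
    unknown = sorted(set(props) - requested)
    if unknown:
        raise ValueError(
            f"Metadata passed but not requested by any child: {unknown}"
        )


# Assumed requests, mirroring the example above.
children = {
    "estimator": {"sample_weight"},
    "scorer": {"sample_weight"},
    "splitter": {"groups"},
}

# Correctly spelled keys pass silently.
validate_props({"sample_weight": [1.0, 2.0], "groups": [0, 1]}, children)

# The misspelled key is rejected instead of being silently ignored.
try:
    validate_props({"sample_weigh": [1.0, 2.0], "groups": [0, 1]}, children)
except ValueError as exc:
    typo_error = str(exc)
```

Failing on unknown keys, rather than dropping them, is what stops a misspelled
name from silently producing an unweighted fit.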

Weighted scoring and unweighted fitting
---------------------------------------

Since ``LogisticRegressionCV``, like all scikit-learn estimators, requires that
weights be requested explicitly, we need to state explicitly that
``sample_weight`` is not used for it, so that ``cross_validate`` doesn't pass
it along::

    >>> weighted_acc = make_scorer(accuracy_score).score_requests(
    ...     sample_weight=True
    ... )
    >>> lr = LogisticRegressionCV(
    ...     cv=GroupKFold(), scoring=weighted_acc,
    ... ).fit_requests(sample_weight=False)
    >>> cv_results = cross_validate(
    ...     lr,
    ...     X,
    ...     y,
    ...     cv=GroupKFold(),
    ...     props={"sample_weight": my_weights, "groups": my_groups},
    ...     scoring=weighted_acc,
    ... )

Unweighted feature selection
----------------------------

Unlike ``LogisticRegressionCV``, ``SelectKBest`` doesn't accept weights and
therefore ``"sample_weight"`` is not routed to it::

    >>> weighted_acc = make_scorer(accuracy_score).score_requests(
    ...     sample_weight=True
    ... )
    >>> lr = LogisticRegressionCV(
    ...     cv=GroupKFold(), scoring=weighted_acc,
    ... ).fit_requests(sample_weight=True)
    >>> sel = SelectKBest(k=2)
    >>> pipe = make_pipeline(sel, lr)
    >>> cv_results = cross_validate(
    ...     pipe,
    ...     X,
    ...     y,
    ...     cv=GroupKFold(),
    ...     props={"sample_weight": my_weights, "groups": my_groups},
    ...     scoring=weighted_acc,
    ... )

Different scoring and fitting weights
-------------------------------------

Despite ``make_scorer`` and ``LogisticRegressionCV`` both expecting a key
``sample_weight``, we can use aliases to pass different weights to different
consumers. In this example, we pass ``scoring_weight`` to the scorer, and
``fitting_weight`` to ``LogisticRegressionCV``::

    >>> weighted_acc = make_scorer(accuracy_score).score_requests(
    ...     sample_weight="scoring_weight"
    ... )
    >>> lr = LogisticRegressionCV(
    ...     cv=GroupKFold(), scoring=weighted_acc,
    ... ).fit_requests(sample_weight="fitting_weight")
    >>> cv_results = cross_validate(
    ...     lr,
    ...     X,
    ...     y,
    ...     cv=GroupKFold(),
    ...     props={
    ...         "scoring_weight": my_weights,
    ...         "fitting_weight": my_other_weights,
    ...         "groups": my_groups,
    ...     },
    ...     scoring=weighted_acc,
    ... )
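
The alias mechanism can be pictured with a small stand-alone sketch in plain
Python. This is only an illustration of the idea, not the routing code from
this pull request; the ``route`` helper and the shape of its ``requests``
argument are assumptions made for the example:

```python
def route(props, requests):
    """Build the keyword arguments that one consumer receives.

    ``requests`` maps the consumer's own parameter name to the key under
    which the user supplies that metadata in ``props`` (the alias).
    """
    return {
        param: props[alias]
        for param, alias in requests.items()
        if alias in props
    }


# What the user passes, using the alias names from the example above.
props = {
    "scoring_weight": [0.1, 0.9],
    "fitting_weight": [0.5, 0.5],
    "groups": [0, 1],
}

# Each consumer sees the metadata under the name it understands.
scorer_kwargs = route(props, {"sample_weight": "scoring_weight"})
estimator_kwargs = route(props, {"sample_weight": "fitting_weight"})
splitter_kwargs = route(props, {"groups": "groups"})
```

In this sketch both the scorer and the estimator end up with a
``sample_weight`` keyword argument, but each receives a different array.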

API Interface
*************

A *consumer* is an object (estimator, meta-estimator, scorer, splitter) which
accepts and uses some metadata in at least one of its methods (``fit``,
``predict``, ``inverse_transform``, ``transform``, ``score``, ``split``).
Meta-estimators which only forward the metadata to other objects (the child
estimator, scorers, or splitters) and don't use the metadata themselves are not
consumers. (Meta)Estimators which route metadata to other objects are routers.
A (meta)estimator can be a consumer and a router at the same time.

(Meta)Estimators and splitters expose a ``*_requests`` method for each method
which accepts at least one piece of metadata. For instance, if an estimator
supports ``sample_weight`` in ``fit`` and ``score``, it exposes
``estimator.fit_requests(sample_weight=value)`` and
``estimator.score_requests(sample_weight=value)``. Here ``value`` can be:

- ``RequestType.REQUESTED`` or ``True``: method requests a ``sample_weight``.
  If the metadata is provided, it is used; if it is not provided, no error is
  raised.
- ``RequestType.UNREQUESTED`` or ``False``: method does not request a
  ``sample_weight``.
- ``RequestType.ERROR_IF_PASSED`` or ``None``: router will raise an error if
  ``sample_weight`` is passed. This is in almost all cases the default value
  when an object is instantiated, and it ensures the user sets the metadata
  requests explicitly when metadata is passed.
- ``"param_name"``: if this estimator is used in a meta-estimator, the
  meta-estimator should forward ``"param_name"`` as ``sample_weight`` to this
  estimator. This means the mapping between the metadata required by the
  object, e.g. ``sample_weight``, and what is provided by the user, e.g.
  ``my_weights``, is done at the router level, and not by the object, e.g. the
  estimator, itself.
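
The four kinds of ``value`` listed above can be summarised in a stand-alone
sketch. The ``RequestType`` enum below is a stand-in written for this example,
not the class added by this pull request, and ``resolve`` is a hypothetical
helper that shows what a router might forward for a single parameter:

```python
from enum import Enum


class RequestType(Enum):
    # Stand-in for the values named in the list above.
    REQUESTED = True
    UNREQUESTED = False
    ERROR_IF_PASSED = None


def resolve(param, request, props):
    """Return the kwargs to forward for ``param`` given its request value."""
    if isinstance(request, str):
        # Alias: the user supplies the metadata under the name ``request``.
        return {param: props[request]} if request in props else {}
    request = RequestType(request)  # also accepts True / False / None
    if request is RequestType.REQUESTED:
        # Used if provided; no error if it is missing.
        return {param: props[param]} if param in props else {}
    if request is RequestType.UNREQUESTED:
        return {}
    # ERROR_IF_PASSED: the user never said what to do with this metadata.
    if param in props:
        raise ValueError(
            f"{param} is passed but is not explicitly set as requested or not."
        )
    return {}


props = {"sample_weight": [1.0, 2.0], "my_weights": [3.0, 4.0]}
requested = resolve("sample_weight", True, props)
unrequested = resolve("sample_weight", False, props)
aliased = resolve("sample_weight", "my_weights", props)
try:
    resolve("sample_weight", None, props)
except ValueError as exc:
    unset_error = str(exc)
```

Only the ``None`` case raises, and only when the metadata is actually present
in ``props``.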

For scorers, this is done the same way, using the ``score_requests`` method.

If metadata, e.g. ``sample_weight``, is passed by the user, the metadata
request for every object which can potentially accept ``sample_weight`` should
be set by the user; otherwise the router object raises an error. For example,
the following code raises, since it has not been explicitly set whether
``sample_weight`` should be passed to the estimator's scorer or not::

    >>> param_grid = {"C": [0.1, 1]}
    >>> lr = LogisticRegression().fit_requests(sample_weight=True)
    >>> try:
    ...     GridSearchCV(
    ...         estimator=lr, param_grid=param_grid
    ...     ).fit(X, y, sample_weight=my_weights)
    ... except ValueError as e:
    ...     print(e)
    sample_weight is passed but is not explicitly set as requested or not. In
    method: score

@@ -30,3 +30,4 @@ User Guide
computing.rst
modules/model_persistence.rst
common_pitfalls.rst
metadata_routing.rst

If the signature checking works, really we don't need this kind of a declaration on a per-name basis. I think these need to match the specification per-method, i.e. `_fit_requests_config = {"sample_weight": REQUESTED, "check_input": REQUESTED}`. `check_input` is a very interesting case. Its use cases might need some thinking about...

How do you propose we handle overriding a `_fit_requests_config` in inheritance? I was hoping to not have to deal with it when we have an attribute per metadata, we pretty much just ignore whatever's been there in the parent.

I'd probably move it to be an `__init__` parameter.
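
For readers following the inheritance question, here is one way a per-method config dict could be combined across a class hierarchy. It is only a sketch and not something proposed in this thread: the `fit_requests_config` classmethod, the three toy classes, and the boolean stand-ins for `REQUESTED` and `UNREQUESTED` are all hypothetical.

```python
REQUESTED, UNREQUESTED = True, False  # stand-ins for the PR's request values


class Base:
    _fit_requests_config = {"sample_weight": REQUESTED}

    @classmethod
    def fit_requests_config(cls):
        """Merge ``_fit_requests_config`` along the MRO, subclasses winning."""
        merged = {}
        # Walk from the most generic class to the most specific one, so that
        # entries declared lower in the hierarchy overwrite inherited ones.
        for klass in reversed(cls.__mro__):
            merged.update(vars(klass).get("_fit_requests_config", {}))
        return merged


class Child(Base):
    # Adds a key; the parent's "sample_weight" entry is kept.
    _fit_requests_config = {"check_input": REQUESTED}


class GrandChild(Child):
    # Overrides a single inherited entry.
    _fit_requests_config = {"sample_weight": UNREQUESTED}


child_config = Child.fit_requests_config()
grandchild_config = GrandChild.fit_requests_config()
```

With a merge like this a subclass declares only the entries it changes. The one-attribute-per-metadata approach mentioned above simply ignores whatever the parent declared.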