sample-props alternate implementation #20350

Closed

Changes from all commits (159 commits)

b0b6fd1
first try...almost
adrinjalali Jan 9, 2020
b33940d
working pipeline
adrinjalali Jan 10, 2020
a397748
grid search
adrinjalali Jan 10, 2020
6a3b725
adding some tests
adrinjalali Jan 13, 2020
0b38a88
Merge remote-tracking branch 'upstream/master' into sample-props
adrinjalali Jan 13, 2020
e2ca8da
pep8
adrinjalali Jan 13, 2020
3472b34
Merge remote-tracking branch 'upstream/master' into sample-props
adrinjalali Feb 20, 2020
5113015
moving function out of base class to validaiton, adding docs
adrinjalali Feb 20, 2020
c70474f
Merge remote-tracking branch 'upstream/master' into sample-props
adrinjalali Mar 24, 2020
8810724
Merge remote-tracking branch 'upstream/master' into sample-props
adrinjalali Mar 24, 2020
4486ec0
refactor and simplify code
adrinjalali Mar 28, 2020
9b47761
merged master, half way trough scoring
adrinjalali Jun 23, 2020
8a116f5
Merge remote-tracking branch 'upstream/master' into sample-props
adrinjalali Jul 3, 2020
987b289
first scoring param in GS works
adrinjalali Jul 3, 2020
09ae61f
simplify set_props_request
adrinjalali Jul 3, 2020
2ef3bad
rename to medata_request
adrinjalali Jul 3, 2020
7f8ae6f
fix docstring and None inputs
adrinjalali Jul 3, 2020
5bd1c83
fix scorers' issues
adrinjalali Jul 3, 2020
d8b55af
minor cleanup
adrinjalali Jul 3, 2020
5b8fd65
tests are okay with parameters not being passed but requested
adrinjalali Jul 3, 2020
9c1b772
accept old style fit params
adrinjalali Jul 3, 2020
8eca900
make test_pipeline pass, ignore future warnings
adrinjalali Jul 5, 2020
8948d45
include sample_weights in **kwargs in metrics
adrinjalali Jul 6, 2020
1851419
pep8
adrinjalali Jul 6, 2020
421a673
separate _MetadataConsumer
adrinjalali Jul 7, 2020
1a1c2b9
don't pass sample_weight=None in metrics
adrinjalali Jul 7, 2020
50317ee
Merge remote-tracking branch 'upstream/master' into sample-props
adrinjalali Jul 10, 2020
4967802
cleanup and ignore private attrs set in __init__ in common tests
adrinjalali Jul 10, 2020
32d8020
rfe passes score params in fit, and cleanup
adrinjalali Jul 10, 2020
845204b
fixes to pass model_selection tests
adrinjalali Jul 10, 2020
213db45
adding tests
adrinjalali Sep 24, 2020
0591102
Merge remote-tracking branch 'upstream/master' into sample-props
adrinjalali Sep 24, 2020
fd1c365
fix some test issues
adrinjalali Sep 24, 2020
08bfc9f
fix some test issues
adrinjalali Sep 24, 2020
72029df
test_props passes
adrinjalali Oct 2, 2020
28132b0
change to specific method to consume specific metadata (Joel's proposal)
adrinjalali Oct 6, 2020
c2d3683
use build_router_metadata_request
adrinjalali Oct 7, 2020
63fe5e0
Merge remote-tracking branch 'upstream/master' into sample-props
adrinjalali Oct 7, 2020
f701477
Joel's minor suggestions
adrinjalali Oct 7, 2020
f973d05
adding SLEP usecases as tests
adrinjalali Oct 15, 2020
8d75b8d
handle params in cross_validate
adrinjalali Oct 18, 2020
1461a4e
slep usecases pass
adrinjalali Oct 18, 2020
29e6b07
Merge remote-tracking branch 'upstream/master' into sample-props
adrinjalali Oct 18, 2020
6d27c79
generalize default values, create GroupConsuer
adrinjalali Oct 19, 2020
450504d
Merge remote-tracking branch 'upstream/master' into sample-props
adrinjalali Oct 20, 2020
121d696
Merge remote-tracking branch 'upstream/master' into sample-props
adrinjalali Oct 29, 2020
554f539
an initial documentation
adrinjalali Oct 29, 2020
2056e6c
fix mocking
adrinjalali Oct 29, 2020
7966b12
fixing tests, broke tests, bisecting
adrinjalali Nov 25, 2020
9da3c2d
slep tests pass again
adrinjalali Nov 26, 2020
115fe9c
fix more issues with metadata routing
adrinjalali Nov 26, 2020
eae11f9
Merge remote-tracking branch 'upstream/master' into sample-props
adrinjalali Nov 26, 2020
4181b92
apply Joel's suggestions
adrinjalali Nov 26, 2020
da7d299
get_metadata_request doesn't change any state
adrinjalali Nov 26, 2020
77ce35c
make _passthrough_scorer implement get_metadata_request
adrinjalali Nov 27, 2020
3515951
fix test_search
adrinjalali Nov 27, 2020
a124017
_passthrough_scorer is a class, fix test
adrinjalali Nov 27, 2020
59fa589
linear_models: need to request sample weight in tests
adrinjalali Nov 27, 2020
6c86d55
fix ensemble tests
adrinjalali Nov 27, 2020
d7d8627
make scorers request sample_weight, and fix inspection tests
adrinjalali Nov 27, 2020
cb00075
fix request_sample_weight example in base.py
adrinjalali Nov 27, 2020
ae363a4
including metadata routing in user guides
adrinjalali Nov 27, 2020
baa3882
use props instead of fit_params in model_selection tests
adrinjalali Nov 27, 2020
849e21e
fix TransformedTargetRegressor and tests
adrinjalali Nov 27, 2020
4b6e331
add metatadata_routing to user_guide toc
adrinjalali Nov 27, 2020
8000b85
add missing import in metadata_routing.rst
adrinjalali Nov 27, 2020
50bf66b
add more missing imports
adrinjalali Nov 27, 2020
b5d3e13
fix metadata_routing.rst
adrinjalali Nov 28, 2020
b117b66
Merge remote-tracking branch 'upstream/master' into sample-props
adrinjalali Nov 28, 2020
941bc60
add some MetadataRequest tests
adrinjalali Nov 30, 2020
57fcbc1
add SampleWeightConsumer to StandardScaler
agramfort Dec 17, 2020
e4e7313
add more tests
adrinjalali Jan 22, 2021
64345e5
apply changes from comments
adrinjalali Mar 8, 2021
951e75c
validate defaults, fix tests
adrinjalali Mar 9, 2021
0d04e3a
merge upstream/main
adrinjalali Mar 9, 2021
3098ca0
Merge branch 'sample-props' of github.com:adrinjalali/scikit-learn in…
adrinjalali Mar 9, 2021
f2266f9
trying new implementation
adrinjalali Apr 18, 2021
477f92f
MetadataRequest seems working
adrinjalali Jun 17, 2021
33e9e54
add {method}_requests methods
adrinjalali Jun 20, 2021
6f5df5f
pipeline uses new MetadataRouter
adrinjalali Jun 22, 2021
a5abfb2
basic GS pass
adrinjalali Jun 23, 2021
aa52ec5
Merge commit '0e7761cdc4f244adb4803f1a97f0a9fe4b365a99' into sample-p…
adrinjalali Jun 23, 2021
bce14b4
MAINT Adds target_version to black config (#20293)
thomasjpfan Jun 17, 2021
f5aa0c8
apply black
adrinjalali Jun 23, 2021
6c4670a
Merge remote-tracking branch 'upstream/main' into sample-props-alternate
adrinjalali Jun 23, 2021
161abc4
fix more issues and pass test_invalid_arg_given
adrinjalali Jun 23, 2021
c3e6450
remove build_router* tests
adrinjalali Jun 23, 2021
c0acf3c
case A passes
adrinjalali Jun 23, 2021
7a7399a
case B passes
adrinjalali Jun 23, 2021
4613667
rename request_sample_weight
adrinjalali Jun 23, 2021
ce4cf3c
case D pass
adrinjalali Jun 23, 2021
bf68dd0
test_props.py passes
adrinjalali Jun 23, 2021
cc55ece
make many more tests pass
adrinjalali Jun 24, 2021
257e351
metrics tests pass
adrinjalali Jun 24, 2021
b5a36cc
linear_model tests pass
adrinjalali Jun 24, 2021
aefada2
fix docstrings
adrinjalali Jun 25, 2021
e97beb2
fix score calls
adrinjalali Jun 25, 2021
2b6071b
ensemble tests pass
adrinjalali Jun 25, 2021
4ec4969
scorer shouldn't validate its inputs
adrinjalali Jun 25, 2021
369a418
model_selection pass
adrinjalali Jun 25, 2021
494a15d
fix docstrings (from Joel's review)
adrinjalali Jun 27, 2021
b803772
remove extras from _passthrough_scorer
adrinjalali Jun 27, 2021
6fa083f
Merge remote-tracking branch 'upstream/main' into sample-props-alternate
adrinjalali Jun 27, 2021
8c18233
adding sample weight support to coordinate descent models
adrinjalali Jun 27, 2021
12d2a7e
Merge remote-tracking branch 'upstream/main' into sample-props-alternate
adrinjalali Jun 27, 2021
a685bb3
more fixes on linear models tests
adrinjalali Jun 28, 2021
99f1fe4
mostly doc fixes
adrinjalali Jun 28, 2021
933d83c
**kwargs -> kwargs and a fix in _logistic
adrinjalali Jun 28, 2021
295e847
Merge remote-tracking branch 'upstream/main' into sample-props-alternate
adrinjalali Jun 28, 2021
ee57e4a
remove output=dict
adrinjalali Jun 28, 2021
455995f
minor doc fix
adrinjalali Jun 28, 2021
a26037c
fix doc code
adrinjalali Jun 28, 2021
ffe9c96
fix futurewarning in coordinate discent tests
adrinjalali Jun 28, 2021
5f627a0
fix numpydoc check on generated docstrings
adrinjalali Jun 28, 2021
79d4198
fix numpydoc check on generated docstrings
adrinjalali Jun 28, 2021
51c9dff
trying to fix docstring indent
adrinjalali Jun 28, 2021
75a40ff
trying again
adrinjalali Jun 28, 2021
89cf0a5
Merge remote-tracking branch 'upstream/main' into sample-props-alternate
adrinjalali Jun 29, 2021
9fc2d93
minor rst fix
adrinjalali Jun 30, 2021
c020c46
Merge remote-tracking branch 'upstream/main' into sample-props-alternate
adrinjalali Jun 30, 2021
3f7c3d7
black test file
adrinjalali Jun 30, 2021
1ab3037
add signature based default values
adrinjalali Jun 30, 2021
f8c9070
add removal of implicit args
adrinjalali Jul 1, 2021
f9490be
ignore first parameter in the signature
adrinjalali Jul 1, 2021
896b24a
improve user guide
adrinjalali Jul 5, 2021
f91b23f
Merge remote-tracking branch 'upstream/main' into sample-props-alternate
adrinjalali Jul 5, 2021
c4e9218
EXISTS_NOT -> UNUSED
adrinjalali Jul 5, 2021
898f356
remove extra print statement from test
adrinjalali Jul 5, 2021
ab42604
remove SampleWeightConsumer
adrinjalali Jul 5, 2021
d106858
add self_metadata and make LogisticRegressionCV validate the input
adrinjalali Jul 6, 2021
ed213df
test edge case in scorers
adrinjalali Jul 6, 2021
ae12ad8
test deprecated fit_params in cross_validate
adrinjalali Jul 7, 2021
cb8d14f
test pipeline with old style and new style of params mixed
adrinjalali Jul 7, 2021
5ae0801
improve metadata_requests.py coverage
adrinjalali Jul 7, 2021
7921d6f
Merge remote-tracking branch 'upstream/main' into sample-props-alternate
adrinjalali Jul 7, 2021
efdf7df
add missing docstring param, catch pipeline test FutureWarning
adrinjalali Jul 8, 2021
e68c4fc
add missing : to versionadded
adrinjalali Jul 8, 2021
9f68c73
rfe numpydoc fix
adrinjalali Jul 8, 2021
931111c
cleanup, reduce diff, versionadded
adrinjalali Jul 9, 2021
34b06c9
Merge remote-tracking branch 'upstream/main' into sample-props-alternate
adrinjalali Jul 9, 2021
4ee9e08
ENH RFE passes metadata to the underlying fit and score methods
adrinjalali Jul 9, 2021
67de752
add changelog
adrinjalali Jul 9, 2021
fb00696
fix extra whitespace
adrinjalali Jul 9, 2021
3e155f0
Merge branch 'rfe-fix' into sample-props-alternate
adrinjalali Jul 9, 2021
c621f66
add tests for test_check_no_attributes_set_in_init
adrinjalali Jul 10, 2021
4056721
fix GB's routing and metadata request
adrinjalali Jul 22, 2021
075e810
Merge remote-tracking branch 'upstream/main' into sample-props-alternate
adrinjalali Jul 22, 2021
e65848e
fix n_classes check in gp
adrinjalali Jul 23, 2021
4cd44eb
better overwrite semantics
adrinjalali Aug 2, 2021
d985145
Merge remote-tracking branch 'upstream/main' into sample-props-alternate
adrinjalali Aug 5, 2021
c707d77
Merge remote-tracking branch 'upstream/main' into sample-props-alternate
adrinjalali Sep 27, 2021
ee90987
add 'smart' overwrite option
adrinjalali Oct 1, 2021
855064f
fix RFE param validation in fit and score
adrinjalali Oct 1, 2021
af1dea5
Merge remote-tracking branch 'upstream/main' into sample-props-alternate
adrinjalali Oct 1, 2021
d1ca507
DOC remove link to doc from getting started
adrinjalali Oct 1, 2021
b5fc380
improve docstring in TargetRegressor
adrinjalali Oct 1, 2021
1034bce
fix overwrite in _gb.py
adrinjalali Oct 1, 2021
1300867
model -> est in _ridge.py
adrinjalali Oct 1, 2021
7341975
remove score_params from make_scorer
adrinjalali Oct 1, 2021
187 changes: 187 additions & 0 deletions doc/metadata_routing.rst
@@ -0,0 +1,187 @@

.. _metadata_routing:

Metadata Routing
================

This guide demonstrates how metadata such as ``sample_weight`` can be routed
and passed along to estimators, scorers, and CV splitters through
meta-estimators such as ``Pipeline`` and ``GridSearchCV``. In order to pass
metadata to a method such as ``fit`` or ``score``, the object accepting the
metadata must *request* it. For estimators and splitters this is done via
``*_requests`` methods, e.g. ``fit_requests(...)``, and for scorers via the
``score_requests`` method. For grouped splitters such as ``GroupKFold``, a
``groups`` parameter is requested by default. This is best demonstrated by the
following examples.

Usage Examples
**************
Here we present a few examples illustrating common use cases. The examples
in this section require the following imports and data::

>>> import numpy as np
>>> from sklearn.metrics import make_scorer, accuracy_score
>>> from sklearn.linear_model import LogisticRegressionCV
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import cross_validate
>>> from sklearn.model_selection import GridSearchCV
>>> from sklearn.model_selection import GroupKFold
>>> from sklearn.feature_selection import SelectKBest
>>> from sklearn.pipeline import make_pipeline
>>> n_samples, n_features = 100, 4
>>> X = np.random.rand(n_samples, n_features)
>>> y = np.random.randint(0, 2, size=n_samples)
>>> my_groups = np.random.randint(0, 10, size=n_samples)
>>> my_weights = np.random.rand(n_samples)
>>> my_other_weights = np.random.rand(n_samples)

Weighted scoring and fitting
----------------------------

Here ``GroupKFold`` requests ``groups`` by default. However, we need to
explicitly request weights for ``make_scorer`` and for ``LogisticRegressionCV``.
Both of these *consumers* understand the meaning of the key
``"sample_weight"``::

>>> weighted_acc = make_scorer(accuracy_score).score_requests(
... sample_weight=True
... )
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(), scoring=weighted_acc,
... ).fit_requests(sample_weight=True)
>>> cv_results = cross_validate(
... lr,
... X,
... y,
... cv=GroupKFold(),
... props={"sample_weight": my_weights, "groups": my_groups},
... scoring=weighted_acc,
... )

Error handling: if ``props={'sample_weigh': my_weights, ...}`` were passed
(note the typo), ``cross_validate`` would raise an error, since
``'sample_weigh'`` was not requested by any of its children.
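
As a minimal sketch (not part of the original guide), the failing call could
look like the following, assuming the router raises a ``ValueError`` as in the
example at the end of this guide::

>>> try:
...     cross_validate(
...         lr,
...         X,
...         y,
...         cv=GroupKFold(),
...         # note the typo in the key name below
...         props={"sample_weigh": my_weights, "groups": my_groups},
...         scoring=weighted_acc,
...     )
... except ValueError:
...     print("sample_weigh was not requested by any child object")
sample_weigh was not requested by any child object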

Weighted scoring and unweighted fitting
---------------------------------------

All scikit-learn estimators, including ``LogisticRegressionCV``, require that
weights be explicitly requested. To keep ``cross_validate`` from passing
``sample_weight`` along to ``LogisticRegressionCV``, we explicitly mark it as
not requested for fitting::

>>> weighted_acc = make_scorer(accuracy_score).score_requests(
... sample_weight=True
... )
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(), scoring=weighted_acc,
... ).fit_requests(sample_weight=False)
>>> cv_results = cross_validate(
... lr,
... X,
... y,
... cv=GroupKFold(),
... props={"sample_weight": my_weights, "groups": my_groups},
... scoring=weighted_acc,
... )

Unweighted feature selection
----------------------------

Unlike ``LogisticRegressionCV``, ``SelectKBest`` doesn't accept weights and
therefore ``"sample_weight"`` is not routed to it::

>>> weighted_acc = make_scorer(accuracy_score).score_requests(
... sample_weight=True
... )
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(), scoring=weighted_acc,
... ).fit_requests(sample_weight=True)
>>> sel = SelectKBest(k=2)
>>> pipe = make_pipeline(sel, lr)
>>> cv_results = cross_validate(
... pipe,
... X,
... y,
... cv=GroupKFold(),
... props={"sample_weight": my_weights, "groups": my_groups},
... scoring=weighted_acc,
... )

Different scoring and fitting weights
-------------------------------------

Even though ``make_scorer`` and ``LogisticRegressionCV`` both expect the key
``sample_weight``, we can use aliases to pass different weights to different
consumers. In this example, we pass ``scoring_weight`` to the scorer and
``fitting_weight`` to ``LogisticRegressionCV``::

>>> weighted_acc = make_scorer(accuracy_score).score_requests(
... sample_weight="scoring_weight"
... )
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(), scoring=weighted_acc,
... ).fit_requests(sample_weight="fitting_weight")
>>> cv_results = cross_validate(
... lr,
... X,
... y,
... cv=GroupKFold(),
... props={
... "scoring_weight": my_weights,
... "fitting_weight": my_other_weights,
... "groups": my_groups,
... },
... scoring=weighted_acc,
... )

API Interface
*************

A *consumer* is an object (estimator, meta-estimator, scorer, or splitter) which
accepts and uses some metadata in at least one of its methods (``fit``,
``predict``, ``inverse_transform``, ``transform``, ``score``, ``split``).
Meta-estimators which only forward the metadata to other objects (the child
estimator, scorers, or splitters) and don't use the metadata themselves are not
consumers. (Meta-)estimators which route metadata to other objects are
*routers*. A (meta-)estimator can be a consumer and a router at the same time.
(Meta-)estimators and splitters expose a ``*_requests`` method for each method
which accepts at least one metadata item. For instance, if an estimator supports
``sample_weight`` in ``fit`` and ``score``, it exposes
``estimator.fit_requests(sample_weight=value)`` and
``estimator.score_requests(sample_weight=value)``. Here ``value`` can be:

- ``RequestType.REQUESTED`` or ``True``: the method requests ``sample_weight``.
If the metadata is provided it will be used; otherwise no error is raised.
- ``RequestType.UNREQUESTED`` or ``False``: the method does not request
``sample_weight``.
- ``RequestType.ERROR_IF_PASSED`` or ``None``: the router will raise an error if
``sample_weight`` is passed. This is the default in almost all cases when an
object is instantiated, and it ensures the user sets the metadata request
explicitly whenever metadata is passed.
- ``"param_name"``: if this estimator is used in a meta-estimator, the
meta-estimator should forward ``"param_name"`` as ``sample_weight`` to this
estimator. This means the mapping between the metadata required by the
object (e.g. ``sample_weight``) and what is provided by the user (e.g.
``my_weights``) is done at the router level, not by the object (e.g. the
estimator) itself.

For scorers, this is done the same way, using the ``score_requests`` method.
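
As an illustrative sketch (not taken verbatim from this guide), and assuming
``LogisticRegressionCV`` accepts ``sample_weight`` in both ``fit`` and
``score`` and therefore exposes both ``fit_requests`` and ``score_requests``,
these values could be set as follows::

>>> est = LogisticRegressionCV()
>>> # request sample_weight for fit, and refuse it for score
>>> est = est.fit_requests(sample_weight=True).score_requests(sample_weight=False)
>>> # or route a user-provided key named "my_weights" to fit as sample_weight
>>> est = est.fit_requests(sample_weight="my_weights")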

If metadata, e.g. ``sample_weight``, is passed by the user, the metadata
request for every object which could potentially accept ``sample_weight`` must
be set by the user; otherwise the router object raises an error. For example,
the following code raises an error, since it has not been explicitly set
whether ``sample_weight`` should be passed to the estimator's scorer or not::

>>> param_grid = {"C": [0.1, 1]}
>>> lr = LogisticRegression().fit_requests(sample_weight=True)
>>> try:
... GridSearchCV(
... estimator=lr, param_grid=param_grid
... ).fit(X, y, sample_weight=my_weights)
... except ValueError as e:
... print(e)
sample_weight is passed but is not explicitly set as requested or not. In
method: score
1 change: 1 addition & 0 deletions doc/user_guide.rst
@@ -30,3 +30,4 @@ User Guide
computing.rst
modules/model_persistence.rst
common_pitfalls.rst
metadata_routing.rst
9 changes: 8 additions & 1 deletion sklearn/base.py
@@ -25,6 +25,7 @@
from .utils.validation import _num_features
from .utils.validation import _check_feature_names_in
from .utils._estimator_html_repr import estimator_html_repr
from .utils.metadata_requests import _MetadataRequester
from .utils.validation import _get_feature_names


@@ -79,7 +80,13 @@ def clone(estimator, *, safe=True):
new_object_params = estimator.get_params(deep=False)
for name, param in new_object_params.items():
new_object_params[name] = clone(param, safe=False)

new_object = klass(**new_object_params)
try:
new_object._metadata_request = copy.deepcopy(estimator._metadata_request)
except AttributeError:
pass

params_set = new_object.get_params(deep=False)

# quick sanity check of the parameters of the clone
@@ -144,7 +151,7 @@ def _pprint(params, offset=0, printer=repr):
return lines


class BaseEstimator:
class BaseEstimator(_MetadataRequester):
"""Base class for all estimators in scikit-learn.

Notes
16 changes: 16 additions & 0 deletions sklearn/compose/_target.py
@@ -307,3 +307,19 @@ def n_features_in_(self):
) from nfe

return self.regressor_.n_features_in_

def get_metadata_request(self):
"""Get requested data properties.

This method mirrors the given regressor's metadata request.

.. versionadded:: 1.1

Returns
-------
request : dict
A dict of dict of str->value. The key to the first dict is the name
of the method, and the key to the second dict is the name of the
argument requested by the method.
"""
return self.regressor.get_metadata_request()
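
A hypothetical illustration (not part of the diff) of the dict-of-dict shape
described in the docstring above, assuming a regressor whose fit requests
sample_weight and whose score leaves it unset:

# hypothetical return value of get_metadata_request(); the actual keys and
# values depend on the wrapped regressor
{
    "fit": {"sample_weight": True},
    "score": {"sample_weight": None},
}
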
12 changes: 9 additions & 3 deletions sklearn/compose/tests/test_target.py
@@ -346,6 +346,8 @@ def test_transform_target_regressor_count_fit(check_inverse):


class DummyRegressorWithExtraFitParams(DummyRegressor):
_metadata_request__check_input = {"fit": "check_input"}
Member:

If the signature checking works, really we don't need this kind of a
declaration on a per-name basis. I think these need to match the specification
per-method, i.e. _fit_requests_config = {"sample_weight": REQUESTED,
"check_input": REQUESTED}.

check_input is a very interesting case. Its use cases might need some thinking
about...

Member Author:

How do you propose we handle overriding a _fit_requests_config in inheritance?
I was hoping to not have to deal with it when we have an attribute per
metadata; we pretty much just ignore whatever's been there in the parent.

Member Author:

> check_input is a very interesting case. Its use cases might need some
> thinking about...

I'd probably move it to be an __init__ parameter.
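
A rough, hypothetical sketch of the two declaration styles under discussion
(the class names are made up, and REQUESTED stands in for the request constant
mentioned in the comment above):

from sklearn.dummy import DummyRegressor

REQUESTED = True  # stand-in for the request constant referenced in the comment

# per-name declaration: one class attribute per metadata item, as used in this
# test file
class PerNameDummy(DummyRegressor):
    _metadata_request__check_input = {"fit": "check_input"}

# per-method declaration: one config dict per consuming method, as proposed in
# the review comment above
class PerMethodDummy(DummyRegressor):
    _fit_requests_config = {"sample_weight": REQUESTED, "check_input": REQUESTED}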


def fit(self, X, y, sample_weight=None, check_input=True):
# on the test below we force this to false, we make sure this is
# actually passed to the regressor
@@ -356,7 +358,10 @@ def fit(self, X, y, sample_weight=None, check_input=True):
def test_transform_target_regressor_pass_fit_parameters():
X, y = friedman
regr = TransformedTargetRegressor(
regressor=DummyRegressorWithExtraFitParams(), transformer=DummyTransformer()
regressor=DummyRegressorWithExtraFitParams().fit_requests(
sample_weight=True, check_input=True
),
transformer=DummyTransformer(),
)

regr.fit(X, y, check_input=False)
@@ -367,12 +372,13 @@ def test_transform_target_regressor_route_pipeline():
X, y = friedman

regr = TransformedTargetRegressor(
regressor=DummyRegressorWithExtraFitParams(), transformer=DummyTransformer()
regressor=DummyRegressorWithExtraFitParams().fit_requests(check_input=True),
transformer=DummyTransformer(),
)
estimators = [("normalize", StandardScaler()), ("est", regr)]

pip = Pipeline(estimators)
pip.fit(X, y, **{"est__check_input": False})
pip.fit(X, y, **{"check_input": False})

assert regr.transformer_.fit_counter == 1
