From dd72b396dcabcb0006ca69c91b762d934f5e0150 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Tue, 25 Jul 2023 10:12:46 +0200 Subject: [PATCH 1/2] MNT move common metadata routing test objects --- sklearn/tests/metadata_routing_common.py | 394 ++++++++++++++++++ sklearn/tests/test_metadata_routing.py | 357 ++++------------ .../test_metaestimators_metadata_routing.py | 186 +-------- 3 files changed, 484 insertions(+), 453 deletions(-) create mode 100644 sklearn/tests/metadata_routing_common.py diff --git a/sklearn/tests/metadata_routing_common.py b/sklearn/tests/metadata_routing_common.py new file mode 100644 index 0000000000000..fdfc235f6ad04 --- /dev/null +++ b/sklearn/tests/metadata_routing_common.py @@ -0,0 +1,394 @@ +from functools import partial + +import numpy as np + +from sklearn.base import ( + BaseEstimator, + ClassifierMixin, + MetaEstimatorMixin, + RegressorMixin, + TransformerMixin, + clone, +) +from sklearn.metrics._scorer import _BaseScorer +from sklearn.model_selection import BaseCrossValidator +from sklearn.model_selection._split import GroupsConsumerMixin +from sklearn.utils._metadata_requests import ( + SIMPLE_METHODS, +) +from sklearn.utils.metadata_routing import ( + MetadataRouter, + process_routing, +) + + +def record_metadata(obj, method, record_default=True, **kwargs): + """Utility function to store passed metadata to a method. + + If record_default is False, kwargs whose values are "default" are skipped. + This is so that checks on keyword arguments whose default was not changed + are skipped. + + """ + if not hasattr(obj, "_records"): + obj._records = {} + if not record_default: + kwargs = { + key: val + for key, val in kwargs.items() + if not isinstance(val, str) or (val != "default") + } + obj._records[method] = kwargs + + +def check_recorded_metadata(obj, method, split_params=tuple(), **kwargs): + """Check whether the expected metadata is passed to the object's method. + + Parameters + ---------- + split_params : tuple, default=empty + specifies any parameters which are to be checked as being a subset + of the original values. + + """ + records = getattr(obj, "_records", dict()).get(method, dict()) + assert set(kwargs.keys()) == set(records.keys()) + for key, value in kwargs.items(): + recorded_value = records[key] + # The following condition is used to check for any specified parameters + # being a subset of the original values + if key in split_params and recorded_value is not None: + assert np.isin(recorded_value, value).all() + else: + assert recorded_value is value + + +record_metadata_not_default = partial(record_metadata, record_default=False) + + +def assert_request_is_empty(metadata_request, exclude=None): + """Check if a metadata request dict is empty. + + One can exclude a method or a list of methods from the check using the + ``exclude`` parameter. + """ + if isinstance(metadata_request, MetadataRouter): + for _, route_mapping in metadata_request: + assert_request_is_empty(route_mapping.router) + return + + exclude = [] if exclude is None else exclude + for method in SIMPLE_METHODS: + if method in exclude: + continue + mmr = getattr(metadata_request, method) + props = [ + prop + for prop, alias in mmr.requests.items() + if isinstance(alias, str) or alias is not None + ] + assert not len(props) + + +def assert_request_equal(request, dictionary): + for method, requests in dictionary.items(): + mmr = getattr(request, method) + assert mmr.requests == requests + + empty_methods = [method for method in SIMPLE_METHODS if method not in dictionary] + for method in empty_methods: + assert not len(getattr(request, method).requests) + + +class _Registry(list): + # This list is used to get a reference to the sub-estimators, which are not + # necessarily stored on the metaestimator. We need to override __deepcopy__ + # because the sub-estimators are probably cloned, which would result in a + # new copy of the list, but we need copy and deep copy both to return the + # same instance. + def __deepcopy__(self, memo): + return self + + def __copy__(self): + return self + + +class ConsumingRegressor(RegressorMixin, BaseEstimator): + """A regressor consuming metadata. + + Parameters + ---------- + registry : list, default=None + If a list, the estimator will append itself to the list in order to have + a reference to the estimator later on. Since that reference is not + required in all tests, registration can be skipped by leaving this value + as None. + + """ + + def __init__(self, registry=None): + self.registry = registry + + def partial_fit(self, X, y, sample_weight="default", metadata="default"): + if self.registry is not None: + self.registry.append(self) + + record_metadata_not_default( + self, "partial_fit", sample_weight=sample_weight, metadata=metadata + ) + return self + + def fit(self, X, y, sample_weight="default", metadata="default"): + if self.registry is not None: + self.registry.append(self) + + record_metadata_not_default( + self, "fit", sample_weight=sample_weight, metadata=metadata + ) + return self + + def predict(self, X, sample_weight="default", metadata="default"): + pass # pragma: no cover + + # when needed, uncomment the implementation + # if self.registry is not None: + # self.registry.append(self) + + # record_metadata_not_default( + # self, "predict", sample_weight=sample_weight, metadata=metadata + # ) + # return np.zeros(shape=(len(X),)) + + +class NonConsumingClassifier(ClassifierMixin, BaseEstimator): + """A classifier which accepts no metadata on any method.""" + + def fit(self, X, y): + self.classes_ = [0, 1] + return self + + def predict(self, X): + return np.ones(len(X)) # pragma: no cover + + +class ConsumingClassifier(ClassifierMixin, BaseEstimator): + """A classifier consuming metadata. + + Parameters + ---------- + registry : list, default=None + If a list, the estimator will append itself to the list in order to have + a reference to the estimator later on. Since that reference is not + required in all tests, registration can be skipped by leaving this value + as None. + + """ + + def __init__(self, registry=None): + self.registry = registry + + def partial_fit(self, X, y, sample_weight="default", metadata="default"): + if self.registry is not None: + self.registry.append(self) + + record_metadata_not_default( + self, "partial_fit", sample_weight=sample_weight, metadata=metadata + ) + self.classes_ = [0, 1] + return self + + def fit(self, X, y, sample_weight="default", metadata="default"): + if self.registry is not None: + self.registry.append(self) + + record_metadata_not_default( + self, "fit", sample_weight=sample_weight, metadata=metadata + ) + self.classes_ = [0, 1] + return self + + def predict(self, X, sample_weight="default", metadata="default"): + pass # pragma: no cover + + # when needed, uncomment the implementation + # if self.registry is not None: + # self.registry.append(self) + + # record_metadata_not_default( + # self, "predict", sample_weight=sample_weight, metadata=metadata + # ) + # return np.zeros(shape=(len(X),)) + + def predict_proba(self, X, sample_weight="default", metadata="default"): + if self.registry is not None: + self.registry.append(self) + + record_metadata_not_default( + self, "predict_proba", sample_weight=sample_weight, metadata=metadata + ) + return np.asarray([[0.0, 1.0]] * len(X)) + + def predict_log_proba(self, X, sample_weight="default", metadata="default"): + pass # pragma: no cover + + # when needed, uncomment the implementation + # if self.registry is not None: + # self.registry.append(self) + + # record_metadata_not_default( + # self, "predict_log_proba", sample_weight=sample_weight, metadata=metadata + # ) + # return np.zeros(shape=(len(X), 2)) + + +class ConsumingTransformer(TransformerMixin, BaseEstimator): + """A transformer which accepts metadata on fit and transform. + + Parameters + ---------- + registry : list, default=None + If a list, the estimator will append itself to the list in order to have + a reference to the estimator later on. Since that reference is not + required in all tests, registration can be skipped by leaving this value + as None. + """ + + def __init__(self, registry=None): + self.registry = registry + + def fit(self, X, y=None, sample_weight=None, metadata=None): + if self.registry is not None: + self.registry.append(self) + + record_metadata_not_default( + self, "fit", sample_weight=sample_weight, metadata=metadata + ) + return self + + def transform(self, X, sample_weight=None): + record_metadata(self, "transform", sample_weight=sample_weight) + return X + + +class ConsumingScorer(_BaseScorer): + def __init__(self, registry=None): + super().__init__(score_func="test", sign=1, kwargs={}) + self.registry = registry + + def __call__( + self, estimator, X, y_true, sample_weight="default", metadata="default" + ): + if self.registry is not None: + self.registry.append(self) + + record_metadata_not_default( + self, "score", sample_weight=sample_weight, metadata=metadata + ) + + return 0.0 + + +class ConsumingSplitter(BaseCrossValidator, GroupsConsumerMixin): + def __init__(self, registry=None): + self.registry = registry + + def split(self, X, y=None, groups="default"): + if self.registry is not None: + self.registry.append(self) + + record_metadata_not_default(self, "split", groups=groups) + + split_index = len(X) - 10 + train_indices = range(0, split_index) + test_indices = range(split_index, len(X)) + yield test_indices, train_indices + + def get_n_splits(self, X=None, y=None, groups=None): + pass # pragma: no cover + + def _iter_test_indices(self, X=None, y=None, groups=None): + pass # pragma: no cover + + +class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): + """A meta-regressor which is only a router.""" + + def __init__(self, estimator): + self.estimator = estimator + + def fit(self, X, y, **fit_params): + params = process_routing(self, "fit", fit_params) + self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit) + + def get_metadata_routing(self): + router = MetadataRouter(owner=self.__class__.__name__).add( + estimator=self.estimator, method_mapping="one-to-one" + ) + return router + + +class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): + """A meta-regressor which is also a consumer.""" + + def __init__(self, estimator): + self.estimator = estimator + + def fit(self, X, y, sample_weight=None, **fit_params): + record_metadata(self, "fit", sample_weight=sample_weight) + params = process_routing(self, "fit", fit_params, sample_weight=sample_weight) + self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit) + return self + + def predict(self, X, **predict_params): + params = process_routing(self, "predict", predict_params) + return self.estimator_.predict(X, **params.estimator.predict) + + def get_metadata_routing(self): + router = ( + MetadataRouter(owner=self.__class__.__name__) + .add_self_request(self) + .add(estimator=self.estimator, method_mapping="one-to-one") + ) + return router + + +class WeightedMetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): + """A meta-estimator which also consumes sample_weight itself in ``fit``.""" + + def __init__(self, estimator): + self.estimator = estimator + + def fit(self, X, y, sample_weight=None, **kwargs): + record_metadata(self, "fit", sample_weight=sample_weight) + params = process_routing(self, "fit", kwargs, sample_weight=sample_weight) + self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit) + return self + + def get_metadata_routing(self): + router = ( + MetadataRouter(owner=self.__class__.__name__) + .add_self_request(self) + .add(estimator=self.estimator, method_mapping="fit") + ) + return router + + +class MetaTransformer(MetaEstimatorMixin, TransformerMixin, BaseEstimator): + """A simple meta-transformer.""" + + def __init__(self, transformer): + self.transformer = transformer + + def fit(self, X, y=None, **fit_params): + params = process_routing(self, "fit", fit_params) + self.transformer_ = clone(self.transformer).fit(X, y, **params.transformer.fit) + return self + + def transform(self, X, y=None, **transform_params): + params = process_routing(self, "transform", transform_params) + return self.transformer_.transform(X, **params.transformer.transform) + + def get_metadata_routing(self): + return MetadataRouter(owner=self.__class__.__name__).add( + transformer=self.transformer, method_mapping="one-to-one" + ) diff --git a/sklearn/tests/test_metadata_routing.py b/sklearn/tests/test_metadata_routing.py index 3fc6a9c337f47..83bfbfc012780 100644 --- a/sklearn/tests/test_metadata_routing.py +++ b/sklearn/tests/test_metadata_routing.py @@ -13,13 +13,22 @@ from sklearn import config_context from sklearn.base import ( BaseEstimator, - ClassifierMixin, - MetaEstimatorMixin, - RegressorMixin, - TransformerMixin, clone, ) from sklearn.linear_model import LinearRegression +from sklearn.tests.metadata_routing_common import ( + ConsumingClassifier, + ConsumingRegressor, + ConsumingTransformer, + MetaRegressor, + MetaTransformer, + NonConsumingClassifier, + WeightedMetaClassifier, + WeightedMetaRegressor, + assert_request_equal, + assert_request_is_empty, + check_recorded_metadata, +) from sklearn.utils import metadata_routing from sklearn.utils._metadata_requests import ( COMPOSITE_METHODS, @@ -56,209 +65,6 @@ def enable_slep006(): yield -def assert_request_is_empty(metadata_request, exclude=None): - """Check if a metadata request dict is empty. - - One can exclude a method or a list of methods from the check using the - ``exclude`` parameter. - """ - if isinstance(metadata_request, MetadataRouter): - for _, route_mapping in metadata_request: - assert_request_is_empty(route_mapping.router) - return - - exclude = [] if exclude is None else exclude - for method in SIMPLE_METHODS: - if method in exclude: - continue - mmr = getattr(metadata_request, method) - props = [ - prop - for prop, alias in mmr.requests.items() - if isinstance(alias, str) or alias is not None - ] - assert not len(props) - - -def assert_request_equal(request, dictionary): - for method, requests in dictionary.items(): - mmr = getattr(request, method) - assert mmr.requests == requests - - empty_methods = [method for method in SIMPLE_METHODS if method not in dictionary] - for method in empty_methods: - assert not len(getattr(request, method).requests) - - -def record_metadata(obj, method, record_default=True, **kwargs): - """Utility function to store passed metadata to a method. - - If record_default is False, kwargs whose values are "default" are skipped. - This is so that checks on keyword arguments whose default was not changed - are skipped. - - """ - if not hasattr(obj, "_records"): - obj._records = {} - if not record_default: - kwargs = { - key: val - for key, val in kwargs.items() - if not isinstance(val, str) or (val != "default") - } - obj._records[method] = kwargs - - -def check_recorded_metadata(obj, method, split_params=tuple(), **kwargs): - """Check whether the expected metadata is passed to the object's method. - - Parameters - ---------- - split_params : tuple, default=empty - specifies any parameters which are to be checked as being a subset - of the original values. - - """ - records = getattr(obj, "_records", dict()).get(method, dict()) - assert set(kwargs.keys()) == set(records.keys()) - for key, value in kwargs.items(): - recorded_value = records[key] - # The following condition is used to check for any specified parameters - # being a subset of the original values - if key in split_params and recorded_value is not None: - assert np.isin(recorded_value, value).all() - else: - assert recorded_value is value - - -class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): - """A meta-regressor which is only a router.""" - - def __init__(self, estimator): - self.estimator = estimator - - def fit(self, X, y, **fit_params): - params = process_routing(self, "fit", fit_params) - self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit) - - def get_metadata_routing(self): - router = MetadataRouter(owner=self.__class__.__name__).add( - estimator=self.estimator, method_mapping="one-to-one" - ) - return router - - -class RegressorMetadata(RegressorMixin, BaseEstimator): - """A regressor consuming a metadata.""" - - def fit(self, X, y, sample_weight=None): - record_metadata(self, "fit", sample_weight=sample_weight) - return self - - def predict(self, X): - return np.zeros(shape=(len(X))) - - -class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): - """A meta-regressor which is also a consumer.""" - - def __init__(self, estimator): - self.estimator = estimator - - def fit(self, X, y, sample_weight=None, **fit_params): - record_metadata(self, "fit", sample_weight=sample_weight) - params = process_routing(self, "fit", fit_params, sample_weight=sample_weight) - self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit) - return self - - def predict(self, X, **predict_params): - params = process_routing(self, "predict", predict_params) - return self.estimator_.predict(X, **params.estimator.predict) - - def get_metadata_routing(self): - router = ( - MetadataRouter(owner=self.__class__.__name__) - .add_self_request(self) - .add(estimator=self.estimator, method_mapping="one-to-one") - ) - return router - - -class ClassifierNoMetadata(ClassifierMixin, BaseEstimator): - """An estimator which accepts no metadata on any method.""" - - def fit(self, X, y): - return self - - def predict(self, X): - return np.ones(len(X)) # pragma: no cover - - -class ClassifierFitMetadata(ClassifierMixin, BaseEstimator): - """An estimator accepting two metadata in its ``fit`` method.""" - - def fit(self, X, y, sample_weight=None, brand=None): - record_metadata(self, "fit", sample_weight=sample_weight, brand=brand) - return self - - def predict(self, X): - return np.ones(len(X)) # pragma: no cover - - -class SimpleMetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): - """A meta-estimator which also consumes sample_weight itself in ``fit``.""" - - def __init__(self, estimator): - self.estimator = estimator - - def fit(self, X, y, sample_weight=None, **kwargs): - record_metadata(self, "fit", sample_weight=sample_weight) - params = process_routing(self, "fit", kwargs, sample_weight=sample_weight) - self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit) - return self - - def get_metadata_routing(self): - router = ( - MetadataRouter(owner=self.__class__.__name__) - .add_self_request(self) - .add(estimator=self.estimator, method_mapping="fit") - ) - return router - - -class TransformerMetadata(TransformerMixin, BaseEstimator): - """A transformer which accepts metadata on fit and transform.""" - - def fit(self, X, y=None, brand=None, sample_weight=None): - record_metadata(self, "fit", brand=brand, sample_weight=sample_weight) - return self - - def transform(self, X, sample_weight=None): - record_metadata(self, "transform", sample_weight=sample_weight) - return X - - -class MetaTransformer(MetaEstimatorMixin, TransformerMixin, BaseEstimator): - """A simple meta-transformer.""" - - def __init__(self, transformer): - self.transformer = transformer - - def fit(self, X, y=None, **fit_params): - params = process_routing(self, "fit", fit_params) - self.transformer_ = clone(self.transformer).fit(X, y, **params.transformer.fit) - return self - - def transform(self, X, y=None, **transform_params): - params = process_routing(self, "transform", transform_params) - return self.transformer_.transform(X, **params.transformer.transform) - - def get_metadata_routing(self): - return MetadataRouter(owner=self.__class__.__name__).add( - transformer=self.transformer, method_mapping="one-to-one" - ) - - class SimplePipeline(BaseEstimator): """A very simple pipeline, assuming the last step is always a predictor.""" @@ -334,7 +140,7 @@ def test_assert_request_is_empty(): assert_request_is_empty( MetadataRouter(owner="test") .add_self_request(WeightedMetaRegressor(estimator=None)) - .add(method_mapping="fit", estimator=RegressorMetadata()) + .add(method_mapping="fit", estimator=ConsumingRegressor()) ) @@ -383,30 +189,30 @@ class OddEstimator(BaseEstimator): assert odd_request.fit.requests == {"sample_weight": True} # check other test estimators - assert not len(get_routing_for_object(ClassifierNoMetadata()).fit.requests) - assert_request_is_empty(ClassifierNoMetadata().get_metadata_routing()) + assert not len(get_routing_for_object(NonConsumingClassifier()).fit.requests) + assert_request_is_empty(NonConsumingClassifier().get_metadata_routing()) - trs_request = get_routing_for_object(TransformerMetadata()) + trs_request = get_routing_for_object(ConsumingTransformer()) assert trs_request.fit.requests == { "sample_weight": None, - "brand": None, + "metadata": None, } assert trs_request.transform.requests == { "sample_weight": None, } assert_request_is_empty(trs_request) - est_request = get_routing_for_object(ClassifierFitMetadata()) + est_request = get_routing_for_object(ConsumingClassifier()) assert est_request.fit.requests == { "sample_weight": None, - "brand": None, + "metadata": None, } assert_request_is_empty(est_request) def test_process_routing_invalid_method(): with pytest.raises(TypeError, match="Can only route and process input"): - process_routing(ClassifierFitMetadata(), "invalid_method", {}) + process_routing(ConsumingClassifier(), "invalid_method", {}) def test_process_routing_invalid_object(): @@ -421,52 +227,52 @@ def test_simple_metadata_routing(): # Tests that metadata is properly routed # The underlying estimator doesn't accept or request metadata - clf = SimpleMetaClassifier(estimator=ClassifierNoMetadata()) + clf = WeightedMetaClassifier(estimator=NonConsumingClassifier()) clf.fit(X, y) # Meta-estimator consumes sample_weight, but doesn't forward it to the underlying # estimator - clf = SimpleMetaClassifier(estimator=ClassifierNoMetadata()) + clf = WeightedMetaClassifier(estimator=NonConsumingClassifier()) clf.fit(X, y, sample_weight=my_weights) # If the estimator accepts the metadata but doesn't explicitly say it doesn't # need it, there's an error - clf = SimpleMetaClassifier(estimator=ClassifierFitMetadata()) + clf = WeightedMetaClassifier(estimator=ConsumingClassifier()) err_message = ( "[sample_weight] are passed but are not explicitly set as requested or" - " not for ClassifierFitMetadata.fit" + " not for ConsumingClassifier.fit" ) with pytest.raises(ValueError, match=re.escape(err_message)): clf.fit(X, y, sample_weight=my_weights) # Explicitly saying the estimator doesn't need it, makes the error go away, - # because in this case `SimpleMetaClassifier` consumes `sample_weight`. If + # because in this case `WeightedMetaClassifier` consumes `sample_weight`. If # there was no consumer of sample_weight, passing it would result in an # error. - clf = SimpleMetaClassifier( - estimator=ClassifierFitMetadata().set_fit_request(sample_weight=False) + clf = WeightedMetaClassifier( + estimator=ConsumingClassifier().set_fit_request(sample_weight=False) ) - # this doesn't raise since SimpleMetaClassifier itself is a consumer, + # this doesn't raise since WeightedMetaClassifier itself is a consumer, # and passing metadata to the consumer directly is fine regardless of its # metadata_request values. clf.fit(X, y, sample_weight=my_weights) - check_recorded_metadata(clf.estimator_, "fit", sample_weight=None, brand=None) + check_recorded_metadata(clf.estimator_, "fit") # Requesting a metadata will make the meta-estimator forward it correctly - clf = SimpleMetaClassifier( - estimator=ClassifierFitMetadata().set_fit_request(sample_weight=True) + clf = WeightedMetaClassifier( + estimator=ConsumingClassifier().set_fit_request(sample_weight=True) ) clf.fit(X, y, sample_weight=my_weights) - check_recorded_metadata(clf.estimator_, "fit", sample_weight=my_weights, brand=None) + check_recorded_metadata(clf.estimator_, "fit", sample_weight=my_weights) # And requesting it with an alias - clf = SimpleMetaClassifier( - estimator=ClassifierFitMetadata().set_fit_request( + clf = WeightedMetaClassifier( + estimator=ConsumingClassifier().set_fit_request( sample_weight="alternative_weight" ) ) clf.fit(X, y, alternative_weight=my_weights) - check_recorded_metadata(clf.estimator_, "fit", sample_weight=my_weights, brand=None) + check_recorded_metadata(clf.estimator_, "fit", sample_weight=my_weights) def test_nested_routing(): @@ -474,23 +280,23 @@ def test_nested_routing(): pipeline = SimplePipeline( [ MetaTransformer( - transformer=TransformerMetadata() - .set_fit_request(brand=True, sample_weight=False) + transformer=ConsumingTransformer() + .set_fit_request(metadata=True, sample_weight=False) .set_transform_request(sample_weight=True) ), WeightedMetaRegressor( - estimator=RegressorMetadata().set_fit_request( - sample_weight="inner_weights" - ) + estimator=ConsumingRegressor() + .set_fit_request(sample_weight="inner_weights", metadata=False) + .set_predict_request(sample_weight=False) ).set_fit_request(sample_weight="outer_weights"), ] ) w1, w2, w3 = [1], [2], [3] pipeline.fit( - X, y, brand=my_groups, sample_weight=w1, outer_weights=w2, inner_weights=w3 + X, y, metadata=my_groups, sample_weight=w1, outer_weights=w2, inner_weights=w3 ) check_recorded_metadata( - pipeline.steps_[0].transformer_, "fit", brand=my_groups, sample_weight=None + pipeline.steps_[0].transformer_, "fit", metadata=my_groups, sample_weight=None ) check_recorded_metadata( pipeline.steps_[0].transformer_, "transform", sample_weight=w1 @@ -509,12 +315,12 @@ def test_nested_routing_conflict(): pipeline = SimplePipeline( [ MetaTransformer( - transformer=TransformerMetadata() - .set_fit_request(brand=True, sample_weight=False) + transformer=ConsumingTransformer() + .set_fit_request(metadata=True, sample_weight=False) .set_transform_request(sample_weight=True) ), WeightedMetaRegressor( - estimator=RegressorMetadata().set_fit_request(sample_weight=True) + estimator=ConsumingRegressor().set_fit_request(sample_weight=True) ).set_fit_request(sample_weight="outer_weights"), ] ) @@ -530,13 +336,13 @@ def test_nested_routing_conflict(): ) ), ): - pipeline.fit(X, y, brand=my_groups, sample_weight=w1, outer_weights=w2) + pipeline.fit(X, y, metadata=my_groups, sample_weight=w1, outer_weights=w2) def test_invalid_metadata(): # check that passing wrong metadata raises an error trs = MetaTransformer( - transformer=TransformerMetadata().set_transform_request(sample_weight=True) + transformer=ConsumingTransformer().set_transform_request(sample_weight=True) ) with pytest.raises( TypeError, @@ -546,7 +352,7 @@ def test_invalid_metadata(): # passing a metadata which is not requested by any estimator should also raise trs = MetaTransformer( - transformer=TransformerMetadata().set_transform_request(sample_weight=False) + transformer=ConsumingTransformer().set_transform_request(sample_weight=False) ) with pytest.raises( TypeError, @@ -751,14 +557,14 @@ def test_metadata_router_consumes_method(): cases = [ ( WeightedMetaRegressor( - estimator=RegressorMetadata().set_fit_request(sample_weight=True) + estimator=ConsumingRegressor().set_fit_request(sample_weight=True) ), {"sample_weight"}, {"sample_weight"}, ), ( WeightedMetaRegressor( - estimator=RegressorMetadata().set_fit_request( + estimator=ConsumingRegressor().set_fit_request( sample_weight="my_weights" ) ), @@ -784,13 +590,13 @@ class WeightedMetaRegressorWarn(WeightedMetaRegressor): def test_estimator_warnings(): - class RegressorMetadataWarn(RegressorMetadata): + class ConsumingRegressorWarn(ConsumingRegressor): __metadata_request__fit = {"sample_weight": metadata_routing.WARN} with pytest.warns( UserWarning, match="Support for .* has recently been added to this class" ): - MetaRegressor(estimator=RegressorMetadataWarn()).fit( + MetaRegressor(estimator=ConsumingRegressorWarn()).fit( X, y, sample_weight=my_weights ) @@ -811,12 +617,14 @@ class RegressorMetadataWarn(RegressorMetadata): (MethodMapping.from_str("score"), "[{'callee': 'score', 'caller': 'score'}]"), ( MetadataRouter(owner="test").add( - method_mapping="predict", estimator=RegressorMetadata() + method_mapping="predict", estimator=ConsumingRegressor() ), ( - "{'estimator': {'mapping': [{'callee': 'predict', 'caller': " - "'predict'}], 'router': {'fit': {'sample_weight': None}, " - "'score': {'sample_weight': None}}}}" + "{'estimator': {'mapping': [{'callee': 'predict', 'caller':" + " 'predict'}], 'router': {'fit': {'sample_weight': None, 'metadata':" + " None}, 'partial_fit': {'sample_weight': None, 'metadata': None}," + " 'predict': {'sample_weight': None, 'metadata': None}, 'score':" + " {'sample_weight': None}}}}" ), ), ], @@ -857,7 +665,7 @@ def test_string_representations(obj, string): "Given `obj` is neither a `MetadataRequest` nor does it implement", ), ( - ClassifierFitMetadata(), + ConsumingClassifier(), "set_fit_request", {"invalid": True}, TypeError, @@ -900,14 +708,14 @@ def test_metadatarouter_add_self_request(): assert router._self_request is not request # one can add an estimator as self - est = RegressorMetadata().set_fit_request(sample_weight="my_weights") + est = ConsumingRegressor().set_fit_request(sample_weight="my_weights") router = MetadataRouter(owner="test").add_self_request(obj=est) assert str(router._self_request) == str(est.get_metadata_routing()) assert router._self_request is not est.get_metadata_routing() # adding a consumer+router as self should only add the consumer part est = WeightedMetaRegressor( - estimator=RegressorMetadata().set_fit_request(sample_weight="nested_weights") + estimator=ConsumingRegressor().set_fit_request(sample_weight="nested_weights") ) router = MetadataRouter(owner="test").add_self_request(obj=est) # _get_metadata_request() returns the consumer part of the requests @@ -923,25 +731,27 @@ def test_metadata_routing_add(): # adding one with a string `method_mapping` router = MetadataRouter(owner="test").add( method_mapping="fit", - est=RegressorMetadata().set_fit_request(sample_weight="weights"), + est=ConsumingRegressor().set_fit_request(sample_weight="weights"), ) assert ( str(router) - == "{'est': {'mapping': [{'callee': 'fit', 'caller': 'fit'}], " - "'router': {'fit': {'sample_weight': 'weights'}, 'score': " - "{'sample_weight': None}}}}" + == "{'est': {'mapping': [{'callee': 'fit', 'caller': 'fit'}], 'router': {'fit':" + " {'sample_weight': 'weights', 'metadata': None}, 'partial_fit':" + " {'sample_weight': None, 'metadata': None}, 'predict': {'sample_weight':" + " None, 'metadata': None}, 'score': {'sample_weight': None}}}}" ) # adding one with an instance of MethodMapping router = MetadataRouter(owner="test").add( method_mapping=MethodMapping().add(callee="score", caller="fit"), - est=RegressorMetadata().set_score_request(sample_weight=True), + est=ConsumingRegressor().set_score_request(sample_weight=True), ) assert ( str(router) - == "{'est': {'mapping': [{'callee': 'score', 'caller': 'fit'}], " - "'router': {'fit': {'sample_weight': None}, 'score': " - "{'sample_weight': True}}}}" + == "{'est': {'mapping': [{'callee': 'score', 'caller': 'fit'}], 'router':" + " {'fit': {'sample_weight': None, 'metadata': None}, 'partial_fit':" + " {'sample_weight': None, 'metadata': None}, 'predict': {'sample_weight':" + " None, 'metadata': None}, 'score': {'sample_weight': True}}}}" ) @@ -949,13 +759,13 @@ def test_metadata_routing_get_param_names(): router = ( MetadataRouter(owner="test") .add_self_request( - WeightedMetaRegressor(estimator=RegressorMetadata()).set_fit_request( + WeightedMetaRegressor(estimator=ConsumingRegressor()).set_fit_request( sample_weight="self_weights" ) ) .add( method_mapping="fit", - trs=TransformerMetadata().set_fit_request( + trs=ConsumingTransformer().set_fit_request( sample_weight="transform_weights" ), ) @@ -963,24 +773,23 @@ def test_metadata_routing_get_param_names(): assert ( str(router) - == "{'$self_request': {'fit': {'sample_weight': 'self_weights'}, 'score': " - "{'sample_weight': None}}, 'trs': {'mapping': [{'callee': 'fit', " - "'caller': 'fit'}], 'router': {'fit': {'brand': None, " - "'sample_weight': 'transform_weights'}, 'transform': " - "{'sample_weight': None}}}}" + == "{'$self_request': {'fit': {'sample_weight': 'self_weights'}, 'score':" + " {'sample_weight': None}}, 'trs': {'mapping': [{'callee': 'fit', 'caller':" + " 'fit'}], 'router': {'fit': {'sample_weight': 'transform_weights'," + " 'metadata': None}, 'transform': {'sample_weight': None}}}}" ) assert router._get_param_names( method="fit", return_alias=True, ignore_self_request=False - ) == {"transform_weights", "brand", "self_weights"} + ) == {"transform_weights", "metadata", "self_weights"} # return_alias=False will return original names for "self" assert router._get_param_names( method="fit", return_alias=False, ignore_self_request=False - ) == {"sample_weight", "brand", "transform_weights"} + ) == {"sample_weight", "metadata", "transform_weights"} # ignoring self would remove "sample_weight" assert router._get_param_names( method="fit", return_alias=False, ignore_self_request=True - ) == {"brand", "transform_weights"} + ) == {"metadata", "transform_weights"} # return_alias is ignored when ignore_self_request=True assert router._get_param_names( method="fit", return_alias=True, ignore_self_request=True @@ -1138,9 +947,9 @@ def test_no_feature_flag_raises_error(): """Test that when feature flag disabled, set_{method}_requests raises.""" with config_context(enable_metadata_routing=False): with pytest.raises(RuntimeError, match="This method is only available"): - ClassifierFitMetadata().set_fit_request(sample_weight=True) + ConsumingClassifier().set_fit_request(sample_weight=True) def test_none_metadata_passed(): """Test that passing None as metadata when not requested doesn't raise""" - MetaRegressor(estimator=RegressorMetadata()).fit(X, y, sample_weight=None) + MetaRegressor(estimator=ConsumingRegressor()).fit(X, y, sample_weight=None) diff --git a/sklearn/tests/test_metaestimators_metadata_routing.py b/sklearn/tests/test_metaestimators_metadata_routing.py index 768a57c61dc52..42b869a06b70d 100644 --- a/sklearn/tests/test_metaestimators_metadata_routing.py +++ b/sklearn/tests/test_metaestimators_metadata_routing.py @@ -1,28 +1,29 @@ import copy import re -from functools import partial import numpy as np import pytest from sklearn import config_context -from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin from sklearn.calibration import CalibratedClassifierCV from sklearn.exceptions import UnsetMetadataPassedError from sklearn.linear_model import LogisticRegressionCV -from sklearn.metrics._scorer import _BaseScorer -from sklearn.model_selection import BaseCrossValidator -from sklearn.model_selection._split import GroupsConsumerMixin from sklearn.multioutput import ( ClassifierChain, MultiOutputClassifier, MultiOutputRegressor, RegressorChain, ) +from sklearn.tests.metadata_routing_common import ( + ConsumingClassifier, + ConsumingRegressor, + ConsumingScorer, + ConsumingSplitter, + _Registry, +) from sklearn.tests.test_metadata_routing import ( assert_request_is_empty, check_recorded_metadata, - record_metadata, ) from sklearn.utils.metadata_routing import MetadataRouter @@ -43,179 +44,6 @@ def enable_slep006(): yield -record_metadata_not_default = partial(record_metadata, record_default=False) - - -class _Registry(list): - # This list is used to get a reference to the sub-estimators, which are not - # necessarily stored on the metaestimator. We need to override __deepcopy__ - # because the sub-estimators are probably cloned, which would result in a - # new copy of the list, but we need copy and deep copy both to return the - # same instance. - def __deepcopy__(self, memo): - return self - - def __copy__(self): - return self - - -class ConsumingRegressor(RegressorMixin, BaseEstimator): - """A regressor consuming metadata. - - Parameters - ---------- - registry : list, default=None - If a list, the estimator will append itself to the list in order to have - a reference to the estimator later on. Since that reference is not - required in all tests, registration can be skipped by leaving this value - as None. - - """ - - def __init__(self, registry=None): - self.registry = registry - - def partial_fit(self, X, y, sample_weight="default", metadata="default"): - if self.registry is not None: - self.registry.append(self) - - record_metadata_not_default( - self, "partial_fit", sample_weight=sample_weight, metadata=metadata - ) - return self - - def fit(self, X, y, sample_weight="default", metadata="default"): - if self.registry is not None: - self.registry.append(self) - - record_metadata_not_default( - self, "fit", sample_weight=sample_weight, metadata=metadata - ) - return self - - def predict(self, X, sample_weight="default", metadata="default"): - pass # pragma: no cover - - # when needed, uncomment the implementation - # if self.registry is not None: - # self.registry.append(self) - - # record_metadata_not_default( - # self, "predict", sample_weight=sample_weight, metadata=metadata - # ) - # return np.zeros(shape=(len(X),)) - - -class ConsumingClassifier(ClassifierMixin, BaseEstimator): - """A classifier consuming metadata. - - Parameters - ---------- - registry : list, default=None - If a list, the estimator will append itself to the list in order to have - a reference to the estimator later on. Since that reference is not - required in all tests, registration can be skipped by leaving this value - as None. - - """ - - def __init__(self, registry=None): - self.registry = registry - - def partial_fit(self, X, y, sample_weight="default", metadata="default"): - if self.registry is not None: - self.registry.append(self) - - record_metadata_not_default( - self, "partial_fit", sample_weight=sample_weight, metadata=metadata - ) - self.classes_ = [0, 1] - return self - - def fit(self, X, y, sample_weight="default", metadata="default"): - if self.registry is not None: - self.registry.append(self) - - record_metadata_not_default( - self, "fit", sample_weight=sample_weight, metadata=metadata - ) - self.classes_ = [0, 1] - return self - - def predict(self, X, sample_weight="default", metadata="default"): - pass # pragma: no cover - - # when needed, uncomment the implementation - # if self.registry is not None: - # self.registry.append(self) - - # record_metadata_not_default( - # self, "predict", sample_weight=sample_weight, metadata=metadata - # ) - # return np.zeros(shape=(len(X),)) - - def predict_proba(self, X, sample_weight="default", metadata="default"): - if self.registry is not None: - self.registry.append(self) - - record_metadata_not_default( - self, "predict_proba", sample_weight=sample_weight, metadata=metadata - ) - return np.asarray([[0.0, 1.0]] * len(X)) - - def predict_log_proba(self, X, sample_weight="default", metadata="default"): - pass # pragma: no cover - - # when needed, uncomment the implementation - # if self.registry is not None: - # self.registry.append(self) - - # record_metadata_not_default( - # self, "predict_log_proba", sample_weight=sample_weight, metadata=metadata - # ) - # return np.zeros(shape=(len(X), 2)) - - -class ConsumingScorer(_BaseScorer): - def __init__(self, registry=None): - super().__init__(score_func="test", sign=1, kwargs={}) - self.registry = registry - - def __call__( - self, estimator, X, y_true, sample_weight="default", metadata="default" - ): - if self.registry is not None: - self.registry.append(self) - - record_metadata_not_default( - self, "score", sample_weight=sample_weight, metadata=metadata - ) - - return 0.0 - - -class ConsumingSplitter(BaseCrossValidator, GroupsConsumerMixin): - def __init__(self, registry=None): - self.registry = registry - - def split(self, X, y=None, groups="default"): - if self.registry is not None: - self.registry.append(self) - - record_metadata_not_default(self, "split", groups=groups) - - split_index = len(X) - 10 - train_indices = range(0, split_index) - test_indices = range(split_index, len(X)) - yield test_indices, train_indices - - def get_n_splits(self, X=None, y=None, groups=None): - pass # pragma: no cover - - def _iter_test_indices(self, X=None, y=None, groups=None): - pass # pragma: no cover - - METAESTIMATORS: list = [ { "metaestimator": MultiOutputRegressor, From 51486f523bb992d0ba0065ad29c2bb7f16dd9b3b Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Tue, 25 Jul 2023 17:09:02 +0200 Subject: [PATCH 2/2] add registry to all consumers and test --- sklearn/tests/metadata_routing_common.py | 18 ++++++++++++++++-- sklearn/tests/test_metadata_routing.py | 18 ++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/sklearn/tests/metadata_routing_common.py b/sklearn/tests/metadata_routing_common.py index fdfc235f6ad04..6018a7da9a4c3 100644 --- a/sklearn/tests/metadata_routing_common.py +++ b/sklearn/tests/metadata_routing_common.py @@ -163,7 +163,13 @@ def predict(self, X, sample_weight="default", metadata="default"): class NonConsumingClassifier(ClassifierMixin, BaseEstimator): """A classifier which accepts no metadata on any method.""" + def __init__(self, registry=None): + self.registry = registry + def fit(self, X, y): + if self.registry is not None: + self.registry.append(self) + self.classes_ = [0, 1] return self @@ -330,10 +336,14 @@ def get_metadata_routing(self): class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): """A meta-regressor which is also a consumer.""" - def __init__(self, estimator): + def __init__(self, estimator, registry=None): self.estimator = estimator + self.registry = registry def fit(self, X, y, sample_weight=None, **fit_params): + if self.registry is not None: + self.registry.append(self) + record_metadata(self, "fit", sample_weight=sample_weight) params = process_routing(self, "fit", fit_params, sample_weight=sample_weight) self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit) @@ -355,10 +365,14 @@ def get_metadata_routing(self): class WeightedMetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): """A meta-estimator which also consumes sample_weight itself in ``fit``.""" - def __init__(self, estimator): + def __init__(self, estimator, registry=None): self.estimator = estimator + self.registry = registry def fit(self, X, y, sample_weight=None, **kwargs): + if self.registry is not None: + self.registry.append(self) + record_metadata(self, "fit", sample_weight=sample_weight) params = process_routing(self, "fit", kwargs, sample_weight=sample_weight) self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit) diff --git a/sklearn/tests/test_metadata_routing.py b/sklearn/tests/test_metadata_routing.py index 83bfbfc012780..43edcddaf3add 100644 --- a/sklearn/tests/test_metadata_routing.py +++ b/sklearn/tests/test_metadata_routing.py @@ -25,6 +25,7 @@ NonConsumingClassifier, WeightedMetaClassifier, WeightedMetaRegressor, + _Registry, assert_request_equal, assert_request_is_empty, check_recorded_metadata, @@ -144,6 +145,23 @@ def test_assert_request_is_empty(): ) +@pytest.mark.parametrize( + "estimator", + [ + ConsumingClassifier(registry=_Registry()), + ConsumingRegressor(registry=_Registry()), + ConsumingTransformer(registry=_Registry()), + NonConsumingClassifier(registry=_Registry()), + WeightedMetaClassifier(estimator=ConsumingClassifier(), registry=_Registry()), + WeightedMetaRegressor(estimator=ConsumingRegressor(), registry=_Registry()), + ], +) +def test_estimator_puts_self_in_registry(estimator): + """Check that an estimator puts itself in the registry upon fit.""" + estimator.fit(X, y) + assert estimator in estimator.registry + + @pytest.mark.parametrize( "val, res", [