TST improve metadata routing tests #29226

Merged · 25 commits · Jun 14, 2024

Commits
868d0ff  FEAT allow metadata to be transformed in Pipeline (adrinjalali, Apr 15, 2024)
42dfe81  Merge remote-tracking branch 'upstream/main' into pipeline/transform (adrinjalali, Apr 26, 2024)
94c8bd9  add tests (adrinjalali, Apr 26, 2024)
818da32  add fit_transform (adrinjalali, Apr 26, 2024)
067946c  fix pprint test (adrinjalali, Apr 29, 2024)
ed5edcd  Merge remote-tracking branch 'upstream/main' into pipeline/transform (adrinjalali, May 7, 2024)
85c10a4  add changelog (adrinjalali, May 7, 2024)
ad269ea  much more extensive tests (adrinjalali, May 8, 2024)
1622203  Merge remote-tracking branch 'upstream/main' into pipeline/transform (adrinjalali, May 8, 2024)
5268514  more fixes (adrinjalali, May 24, 2024)
1a4a428  Merge remote-tracking branch 'upstream/main' into pipeline/transform (adrinjalali, May 25, 2024)
052b13d  WIP tests improvements (adrinjalali, May 26, 2024)
278dc70  TST fix pipeline tests (adrinjalali, May 26, 2024)
0716f51  revert pipeline's transform_input (adrinjalali, Jun 10, 2024)
d0724e1  TST remove print (adrinjalali, Jun 10, 2024)
19afa1c  Merge remote-tracking branch 'upstream/main' into slep6/tests (adrinjalali, Jun 10, 2024)
b563cbb  TST fix tests (adrinjalali, Jun 10, 2024)
b926dbc  Update sklearn/tests/metadata_routing_common.py (adrinjalali, Jun 10, 2024)
9d44784  Update sklearn/tests/metadata_routing_common.py (adrinjalali, Jun 10, 2024)
ded64ce  Update sklearn/tests/metadata_routing_common.py (adrinjalali, Jun 10, 2024)
4482191  Update sklearn/tests/metadata_routing_common.py (adrinjalali, Jun 11, 2024)
b343a04  Merge remote-tracking branch 'upstream/main' into slep6/tests (adrinjalali, Jun 11, 2024)
ee64d03  Guillaume's comments (adrinjalali, Jun 11, 2024)
a1023b3  Address Omar's comment (adrinjalali, Jun 14, 2024)
7b75a8e  Merge remote-tracking branch 'upstream/main' into slep6/tests (adrinjalali, Jun 14, 2024)

8 changes: 6 additions & 2 deletions in sklearn/compose/tests/test_column_transformer.py

@@ -2640,15 +2640,19 @@ def test_metadata_routing_for_column_transformer(method):
     )

     if method == "transform":
-        trs.fit(X, y)
+        trs.fit(X, y, sample_weight=sample_weight, metadata=metadata)
         trs.transform(X, sample_weight=sample_weight, metadata=metadata)
     else:
         getattr(trs, method)(X, y, sample_weight=sample_weight, metadata=metadata)

     assert len(registry)
     for _trs in registry:
         check_recorded_metadata(
-            obj=_trs, method=method, sample_weight=sample_weight, metadata=metadata
+            obj=_trs,
+            method=method,
+            parent=method,
+            sample_weight=sample_weight,
+            metadata=metadata,
         )

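A minimal usage sketch of the updated helper (a hypothetical script, not part of the PR; assumes a build with this branch installed, and note that the arrays and the "<module>" caller name below are illustrative):

    import numpy as np
    from sklearn.tests.metadata_routing_common import (
        ConsumingTransformer,
        check_recorded_metadata,
    )

    sw = np.array([1.0, 2.0])
    trs = ConsumingTransformer().fit(
        np.array([[1.0], [2.0]]), np.array([0, 1]), sample_weight=sw, metadata="meta"
    )
    # `method` names the callee that received the metadata; the new `parent`
    # argument names the caller that routed it there. When `fit` is invoked
    # directly at a script's top level, the caller frame is "<module>";
    # inside a test function it would be that function's name instead.
    check_recorded_metadata(
        obj=trs, method="fit", parent="<module>", sample_weight=sw, metadata="meta"
    )
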
12 changes: 10 additions & 2 deletions in sklearn/ensemble/tests/test_stacking.py

@@ -973,13 +973,21 @@ def test_metadata_routing_for_stacking_estimators(Estimator, Child, prop, prop_v
     assert len(registry)
     for sub_est in registry:
         check_recorded_metadata(
-            obj=sub_est, method="fit", split_params=(prop), **{prop: prop_value}
+            obj=sub_est,
+            method="fit",
+            parent="fit",
+            split_params=(prop),
+            **{prop: prop_value},
         )
     # access final_estimator:
     registry = est.final_estimator_.registry
     assert len(registry)
     check_recorded_metadata(
-        obj=registry[-1], method="predict", split_params=(prop), **{prop: prop_value}
+        obj=registry[-1],
+        method="predict",
+        parent="predict",
+        split_params=(prop),
+        **{prop: prop_value},
     )

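The split_params argument matters here because stacking fits sub-estimators on cross-validation partitions, so each recorded sample_weight is a slice of the full array and the helper checks subset membership rather than equality. (Incidentally, `(prop)` is a plain string rather than a one-element tuple; the `key in split_params` test still matches the intended key because a string contains itself.) A small illustration of the subset check:

    import numpy as np

    expected = np.array([1.0, 2.0, 3.0, 4.0])  # full metadata given to the meta-estimator
    recorded = np.array([2.0, 4.0])            # what one CV fold's sub-estimator saw
    assert np.isin(recorded, expected).all()   # the check applied to split_params keys
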
2 changes: 1 addition & 1 deletion in sklearn/ensemble/tests/test_voting.py

@@ -759,7 +759,7 @@ def test_metadata_routing_for_voting_estimators(Estimator, Child, prop):
         registry = estimator[1].registry
         assert len(registry)
         for sub_est in registry:
-            check_recorded_metadata(obj=sub_est, method="fit", **kwargs)
+            check_recorded_metadata(obj=sub_est, method="fit", parent="fit", **kwargs)


 @pytest.mark.usefixtures("enable_slep006")

1 change: 1 addition & 0 deletions in sklearn/model_selection/tests/test_search.py

@@ -2614,6 +2614,7 @@ def test_multi_metric_search_forwards_metadata(SearchCV, param_search):
         check_recorded_metadata(
             obj=_scorer,
             method="score",
+            parent="_score",
             split_params=("sample_weight", "metadata"),
             sample_weight=score_weights,
             metadata=score_metadata,

4 changes: 4 additions & 0 deletions in sklearn/model_selection/tests/test_validation.py

@@ -2601,6 +2601,7 @@ def test_validation_functions_routing(func):
         check_recorded_metadata(
             obj=_scorer,
             method="score",
+            parent=func.__name__,
             split_params=("sample_weight", "metadata"),
             sample_weight=score_weights,
             metadata=score_metadata,
@@ -2611,6 +2612,7 @@ def test_validation_functions_routing(func):
         check_recorded_metadata(
             obj=_splitter,
             method="split",
+            parent=func.__name__,
             groups=split_groups,
             metadata=split_metadata,
         )
@@ -2620,6 +2622,7 @@ def test_validation_functions_routing(func):
         check_recorded_metadata(
             obj=_estimator,
             method="fit",
+            parent=func.__name__,
             split_params=("sample_weight", "metadata"),
             sample_weight=fit_sample_weight,
             metadata=fit_metadata,
@@ -2657,6 +2660,7 @@ def test_learning_curve_exploit_incremental_learning_routing():
         check_recorded_metadata(
             obj=_estimator,
             method="partial_fit",
+            parent="learning_curve",
             split_params=("sample_weight", "metadata"),
             sample_weight=fit_sample_weight,
             metadata=fit_metadata,

116 changes: 65 additions & 51 deletions in sklearn/tests/metadata_routing_common.py

@@ -1,3 +1,5 @@
+import inspect
+from collections import defaultdict
 from functools import partial

 import numpy as np
@@ -25,55 +27,69 @@
 from sklearn.utils.multiclass import _check_partial_fit_first_call


-def record_metadata(obj, method, record_default=True, **kwargs):
-    """Utility function to store passed metadata to a method.
+def record_metadata(obj, record_default=True, **kwargs):
+    """Utility function to store passed metadata to a method of obj.

     If record_default is False, kwargs whose values are "default" are skipped.
     This is so that checks on keyword arguments whose default was not changed
     are skipped.

     """
+    stack = inspect.stack()
+    callee = stack[1].function
+    caller = stack[2].function
     if not hasattr(obj, "_records"):
-        obj._records = {}
+        obj._records = defaultdict(lambda: defaultdict(list))
     if not record_default:
         kwargs = {
             key: val
             for key, val in kwargs.items()
             if not isinstance(val, str) or (val != "default")
         }
-    obj._records[method] = kwargs
+    obj._records[callee][caller].append(kwargs)


-def check_recorded_metadata(obj, method, split_params=tuple(), **kwargs):
+def check_recorded_metadata(obj, method, parent, split_params=tuple(), **kwargs):
     """Check whether the expected metadata is passed to the object's method.

     Parameters
     ----------
     obj : estimator object
         sub-estimator to check routed params for
     method : str
-        sub-estimator's method where metadata is routed to
+        sub-estimator's method where metadata is routed to, or otherwise in
+        the context of metadata routing referred to as 'callee'
+    parent : str
+        the parent method which should have called `method`, or otherwise in
+        the context of metadata routing referred to as 'caller'
     split_params : tuple, default=empty
         specifies any parameters which are to be checked as being a subset
         of the original values
     **kwargs : dict
         passed metadata
     """
-    records = getattr(obj, "_records", dict()).get(method, dict())
-    assert set(kwargs.keys()) == set(
-        records.keys()
-    ), f"Expected {kwargs.keys()} vs {records.keys()}"
-    for key, value in kwargs.items():
-        recorded_value = records[key]
-        # The following condition is used to check for any specified parameters
-        # being a subset of the original values
-        if key in split_params and recorded_value is not None:
-            assert np.isin(recorded_value, value).all()
-        else:
-            if isinstance(recorded_value, np.ndarray):
-                assert_array_equal(recorded_value, value)
-            else:
-                assert recorded_value is value, f"Expected {recorded_value} vs {value}"
+    all_records = (
+        getattr(obj, "_records", dict()).get(method, dict()).get(parent, list())
+    )
+    for record in all_records:
+        # first check that the names of the metadata passed are the same as
+        # expected. The names are stored as keys in `record`.
+        assert set(kwargs.keys()) == set(
+            record.keys()
+        ), f"Expected {kwargs.keys()} vs {record.keys()}"
+        for key, value in kwargs.items():
+            recorded_value = record[key]
+            # The following condition is used to check for any specified parameters
+            # being a subset of the original values
+            if key in split_params and recorded_value is not None:
+                assert np.isin(recorded_value, value).all()
+            else:
+                if isinstance(recorded_value, np.ndarray):
+                    assert_array_equal(recorded_value, value)
+                else:
+                    assert (
+                        recorded_value is value
+                    ), f"Expected {recorded_value} vs {value}. Method: {method}"


 record_metadata_not_default = partial(record_metadata, record_default=False)
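
To make the recording scheme above concrete, here is a self-contained sketch of the same technique with toy classes (names are illustrative, not the library's): record_metadata reads the call stack, so the callee is the method that received the metadata and the caller is whichever method routed it there.

    import inspect
    from collections import defaultdict

    def record(obj, **kwargs):
        stack = inspect.stack()
        callee = stack[1].function  # method that received the metadata
        caller = stack[2].function  # method that routed it there
        if not hasattr(obj, "_records"):
            obj._records = defaultdict(lambda: defaultdict(list))
        obj._records[callee][caller].append(kwargs)

    class Sub:
        def fit(self, X, sample_weight=None):
            record(self, sample_weight=sample_weight)
            return self

    class Meta:
        def __init__(self, sub):
            self.sub = sub

        def fit(self, X, sample_weight=None):
            self.sub.fit(X, sample_weight=sample_weight)  # caller frame: Meta.fit
            return self

    sub = Sub()
    Meta(sub).fit([[1.0]], sample_weight=[2.0])
    # callee "fit" was invoked by caller "fit" (Meta.fit), exactly once:
    assert sub._records["fit"]["fit"] == [{"sample_weight": [2.0]}]
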
@@ -151,7 +167,7 @@ def partial_fit(self, X, y, sample_weight="default", metadata="default"):
         self.registry.append(self)

         record_metadata_not_default(
-            self, "partial_fit", sample_weight=sample_weight, metadata=metadata
+            self, sample_weight=sample_weight, metadata=metadata
         )
         return self
@@ -160,19 +176,19 @@ def fit(self, X, y, sample_weight="default", metadata="default"):
         self.registry.append(self)

         record_metadata_not_default(
-            self, "fit", sample_weight=sample_weight, metadata=metadata
+            self, sample_weight=sample_weight, metadata=metadata
         )
         return self

     def predict(self, X, y=None, sample_weight="default", metadata="default"):
         record_metadata_not_default(
-            self, "predict", sample_weight=sample_weight, metadata=metadata
+            self, sample_weight=sample_weight, metadata=metadata
         )
         return np.zeros(shape=(len(X),))

     def score(self, X, y, sample_weight="default", metadata="default"):
         record_metadata_not_default(
-            self, "score", sample_weight=sample_weight, metadata=metadata
+            self, sample_weight=sample_weight, metadata=metadata
         )
         return 1
@@ -240,7 +256,7 @@ def partial_fit(
         self.registry.append(self)

         record_metadata_not_default(
-            self, "partial_fit", sample_weight=sample_weight, metadata=metadata
+            self, sample_weight=sample_weight, metadata=metadata
         )
         _check_partial_fit_first_call(self, classes)
         return self
@@ -250,15 +266,15 @@ def fit(self, X, y, sample_weight="default", metadata="default"):
         self.registry.append(self)

         record_metadata_not_default(
-            self, "fit", sample_weight=sample_weight, metadata=metadata
+            self, sample_weight=sample_weight, metadata=metadata
         )

         self.classes_ = np.unique(y)
         return self

     def predict(self, X, sample_weight="default", metadata="default"):
         record_metadata_not_default(
-            self, "predict", sample_weight=sample_weight, metadata=metadata
+            self, sample_weight=sample_weight, metadata=metadata
         )
         y_score = np.empty(shape=(len(X),), dtype="int8")
         y_score[len(X) // 2 :] = 0
@@ -267,7 +283,7 @@ def predict(self, X, sample_weight="default", metadata="default"):

     def predict_proba(self, X, sample_weight="default", metadata="default"):
         record_metadata_not_default(
-            self, "predict_proba", sample_weight=sample_weight, metadata=metadata
+            self, sample_weight=sample_weight, metadata=metadata
         )
         y_proba = np.empty(shape=(len(X), 2))
         y_proba[: len(X) // 2, :] = np.asarray([1.0, 0.0])
@@ -279,13 +295,13 @@ def predict_log_proba(self, X, sample_weight="default", metadata="default"):

         # uncomment when needed
         # record_metadata_not_default(
-        #     self, "predict_log_proba", sample_weight=sample_weight, metadata=metadata
+        #     self, sample_weight=sample_weight, metadata=metadata
         # )
         # return np.zeros(shape=(len(X), 2))

     def decision_function(self, X, sample_weight="default", metadata="default"):
         record_metadata_not_default(
-            self, "predict_proba", sample_weight=sample_weight, metadata=metadata
+            self, sample_weight=sample_weight, metadata=metadata
         )
         y_score = np.empty(shape=(len(X),))
         y_score[len(X) // 2 :] = 0
@@ -295,7 +311,7 @@ def decision_function(self, X, sample_weight="default", metadata="default"):

     # uncomment when needed
     # def score(self, X, y, sample_weight="default", metadata="default"):
     #     record_metadata_not_default(
-    #         self, "score", sample_weight=sample_weight, metadata=metadata
+    #         self, sample_weight=sample_weight, metadata=metadata
     #     )
     #     return 1
@@ -315,38 +331,38 @@ class ConsumingTransformer(TransformerMixin, BaseEstimator):
     def __init__(self, registry=None):
         self.registry = registry

-    def fit(self, X, y=None, sample_weight=None, metadata=None):
+    def fit(self, X, y=None, sample_weight="default", metadata="default"):
         if self.registry is not None:
             self.registry.append(self)

         record_metadata_not_default(
-            self, "fit", sample_weight=sample_weight, metadata=metadata
+            self, sample_weight=sample_weight, metadata=metadata
         )
         return self

-    def transform(self, X, sample_weight=None, metadata=None):
-        record_metadata(
-            self, "transform", sample_weight=sample_weight, metadata=metadata
+    def transform(self, X, sample_weight="default", metadata="default"):
+        record_metadata_not_default(
+            self, sample_weight=sample_weight, metadata=metadata
         )
-        return X
+        return X + 1

-    def fit_transform(self, X, y, sample_weight=None, metadata=None):
+    def fit_transform(self, X, y, sample_weight="default", metadata="default"):
         # implementing ``fit_transform`` is necessary since
         # ``TransformerMixin.fit_transform`` doesn't route any metadata to
         # ``transform``, while here we want ``transform`` to receive
         # ``sample_weight`` and ``metadata``.
-        record_metadata(
-            self, "fit_transform", sample_weight=sample_weight, metadata=metadata
+        record_metadata_not_default(
+            self, sample_weight=sample_weight, metadata=metadata
         )
         return self.fit(X, y, sample_weight=sample_weight, metadata=metadata).transform(
             X, sample_weight=sample_weight, metadata=metadata
         )

     def inverse_transform(self, X, sample_weight=None, metadata=None):
-        record_metadata(
-            self, "inverse_transform", sample_weight=sample_weight, metadata=metadata
+        record_metadata_not_default(
+            self, sample_weight=sample_weight, metadata=metadata
         )
-        return X
+        return X - 1

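The comment in fit_transform above is worth a concrete check: TransformerMixin.fit_transform forwards **fit_params to fit only, never to transform. A rough sketch under that assumption (a toy class with illustrative names, not part of the PR):

    from sklearn.base import BaseEstimator, TransformerMixin

    class Toy(TransformerMixin, BaseEstimator):
        def fit(self, X, y=None, sample_weight=None):
            self.fit_saw_ = sample_weight
            return self

        def transform(self, X, sample_weight=None):
            self.transform_saw_ = sample_weight
            return X

    toy = Toy()
    toy.fit_transform([[1.0]], sample_weight=[1.0])
    assert toy.fit_saw_ == [1.0]       # `fit` received the metadata
    assert toy.transform_saw_ is None  # the mixin did not forward it to `transform`
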


@@ -361,14 +377,12 @@ class ConsumingNoFitTransformTransformer(BaseEstimator):
     def fit(self, X, y=None, sample_weight=None, metadata=None):
         if self.registry is not None:
             self.registry.append(self)

-        record_metadata(self, "fit", sample_weight=sample_weight, metadata=metadata)
+        record_metadata(self, sample_weight=sample_weight, metadata=metadata)

         return self

     def transform(self, X, sample_weight=None, metadata=None):
-        record_metadata(
-            self, "transform", sample_weight=sample_weight, metadata=metadata
-        )
+        record_metadata(self, sample_weight=sample_weight, metadata=metadata)
         return X

@@ -383,7 +397,7 @@ def _score(self, method_caller, clf, X, y, **kwargs):
         if self.registry is not None:
             self.registry.append(self)

-        record_metadata_not_default(self, "score", **kwargs)
+        record_metadata_not_default(self, **kwargs)

         sample_weight = kwargs.get("sample_weight", None)
         return super()._score(method_caller, clf, X, y, sample_weight=sample_weight)
@@ -397,7 +411,7 @@ def split(self, X, y=None, groups="default", metadata="default"):
         if self.registry is not None:
             self.registry.append(self)

-        record_metadata_not_default(self, "split", groups=groups, metadata=metadata)
+        record_metadata_not_default(self, groups=groups, metadata=metadata)

         split_index = len(X) // 2
         train_indices = list(range(0, split_index))
@@ -445,7 +459,7 @@ def fit(self, X, y, sample_weight=None, **fit_params):
         if self.registry is not None:
             self.registry.append(self)

-        record_metadata(self, "fit", sample_weight=sample_weight)
+        record_metadata(self, sample_weight=sample_weight)
         params = process_routing(self, "fit", sample_weight=sample_weight, **fit_params)
         self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
         return self
@@ -479,7 +493,7 @@ def fit(self, X, y, sample_weight=None, **kwargs):
         if self.registry is not None:
             self.registry.append(self)

-        record_metadata(self, "fit", sample_weight=sample_weight)
+        record_metadata(self, sample_weight=sample_weight)
         params = process_routing(self, "fit", sample_weight=sample_weight, **kwargs)
         self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
         return self

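For readers new to the routing calls in the hunks above, this is roughly the pattern the consuming meta-estimators follow; the class below is a hypothetical sketch, not the module's actual meta-estimator:

    from sklearn.base import BaseEstimator, RegressorMixin, clone
    from sklearn.utils.metadata_routing import (
        MetadataRouter,
        MethodMapping,
        process_routing,
    )

    class SketchMetaRegressor(RegressorMixin, BaseEstimator):
        def __init__(self, estimator):
            self.estimator = estimator

        def get_metadata_routing(self):
            # declare that our `fit` forwards metadata to the sub-estimator's `fit`
            return MetadataRouter(owner=self.__class__.__name__).add(
                estimator=self.estimator,
                method_mapping=MethodMapping().add(caller="fit", callee="fit"),
            )

        def fit(self, X, y, **fit_params):
            # validate and split the incoming metadata per sub-object and method
            params = process_routing(self, "fit", **fit_params)
            self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
            return self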