Description
Describe the workflow you want to enable
I would like to pass sample properties to the response method (e.g. `predict`) called by a scorer.
For example, the fairlearn package has a `ThresholdOptimizer` estimator which needs the `sensitive_features` argument (in addition to `X` and `y`) for both `fit` and `predict`.
AFAICT I can pass arguments to the score function (the metric), but not to the response method of the estimator.
import numpy as np
import sklearn
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, make_scorer
from fairlearn.postprocessing import ThresholdOptimizer
from fairlearn.metrics import demographic_parity_difference
sklearn.set_config(enable_metadata_routing=True)
rng = np.random.default_rng(0)
X = rng.normal(size=(10, 3))
y = rng.integers(0, 2, size=X.shape[0])
sensitive = rng.integers(0, 2, size=X.shape[0])
classifier = (
    ThresholdOptimizer(estimator=DummyClassifier(), predict_method="auto")
    .set_fit_request(sensitive_features=True)
    .set_predict_request(sensitive_features=True)
    .fit(X, y, sensitive_features=sensitive)
)
scoring = make_scorer(accuracy_score)
scoring(classifier, X, y, sensitive_features=sensitive)  # TypeError: predict() missing 1 argument. How can I pass `sensitive_features` to predict()?
# By contrast, passing arguments to the score function (demographic_parity_difference) works:
classifier = DummyClassifier().fit(X, y)
scoring = make_scorer(
    demographic_parity_difference, greater_is_better=False
).set_score_request(sensitive_features=True)
scoring(classifier, X, y, sensitive_features=sensitive)
The same problem arises when a scorer is used indirectly, for example in `cross_validate`.
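To make the asymmetry concrete, here is a simplified, pure-Python model of the behaviour reported above; it is not scikit-learn's actual implementation. `GroupThresholdClassifier`, `accuracy` and `naive_scorer` are hypothetical stand-ins for `ThresholdOptimizer`, `accuracy_score` and a `make_scorer` scorer. The toy scorer calls `predict(X)` with no extra arguments and forwards every keyword argument only to the metric, so the required `sensitive_features` never reaches `predict`.

```python
# Simplified model of the reported behaviour (not scikit-learn's code).


class GroupThresholdClassifier:
    """Hypothetical stand-in for ThresholdOptimizer: predict() needs sensitive_features."""

    def __init__(self, thresholds):
        self.thresholds = thresholds  # one decision threshold per group

    def predict(self, X, *, sensitive_features):
        # Predict 1 when the (single) feature exceeds its group's threshold.
        return [int(x > self.thresholds[g]) for x, g in zip(X, sensitive_features)]


def accuracy(y_true, y_pred):
    return sum(a == b for a, b in zip(y_true, y_pred)) / len(y_true)


def naive_scorer(score_func):
    """Hypothetical make_scorer-like factory: extra kwargs reach the metric only."""

    def scorer(estimator, X, y, **kwargs):
        y_pred = estimator.predict(X)  # sensitive_features cannot be forwarded here
        return score_func(y, y_pred, **kwargs)

    return scorer


X = [0.2, 0.7, 0.4, 0.9]
y = [0, 1, 1, 1]
sensitive = [0, 0, 1, 1]
clf = GroupThresholdClassifier(thresholds={0: 0.5, 1: 0.3})

error_message = None
try:
    naive_scorer(accuracy)(clf, X, y, sensitive_features=sensitive)
except TypeError as exc:
    error_message = str(exc)  # predict() was called without sensitive_features
print(error_message)
```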
Describe your proposed solution
Maybe the scorers could have a method like `set_predict_request` or `set_response_request` to specify which parameters should be forwarded to the response method?
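One possible shape for this proposal, sketched in plain Python under the assumption that a scorer keeps two separate request sets: one for the response method and one for the score function. `RoutingScorer` and its `set_response_request` method are illustrative names only, as are the toy estimator and metrics; `selection_rate_difference` only loosely mimics `demographic_parity_difference`.

```python
class RoutingScorer:
    """Hypothetical scorer that routes metadata to the response method and/or the metric."""

    def __init__(self, score_func, response_method="predict"):
        self.score_func = score_func
        self.response_method = response_method
        self._response_params = set()
        self._score_params = set()

    def set_response_request(self, **requests):
        # Names flagged True are forwarded to the estimator's response method.
        self._response_params = {name for name, flag in requests.items() if flag}
        return self

    def set_score_request(self, **requests):
        # Names flagged True are forwarded to the score function (the metric).
        self._score_params = {name for name, flag in requests.items() if flag}
        return self

    def __call__(self, estimator, X, y, **kwargs):
        unknown = set(kwargs) - self._response_params - self._score_params
        if unknown:
            raise TypeError(f"metadata not requested by this scorer: {sorted(unknown)}")
        response_kwargs = {k: v for k, v in kwargs.items() if k in self._response_params}
        score_kwargs = {k: v for k, v in kwargs.items() if k in self._score_params}
        y_pred = getattr(estimator, self.response_method)(X, **response_kwargs)
        return self.score_func(y, y_pred, **score_kwargs)


class GroupThresholdClassifier:
    """Hypothetical estimator whose predict() needs sensitive_features."""

    def __init__(self, thresholds):
        self.thresholds = thresholds

    def predict(self, X, *, sensitive_features):
        return [int(x > self.thresholds[g]) for x, g in zip(X, sensitive_features)]


def accuracy(y_true, y_pred):
    return sum(a == b for a, b in zip(y_true, y_pred)) / len(y_true)


def selection_rate_difference(y_true, y_pred, *, sensitive_features):
    # Toy demographic-parity-style metric: max minus min positive rate per group.
    preds_by_group = {}
    for pred, group in zip(y_pred, sensitive_features):
        preds_by_group.setdefault(group, []).append(pred)
    rates = [sum(p) / len(p) for p in preds_by_group.values()]
    return max(rates) - min(rates)


X = [0.2, 0.7, 0.4, 0.9]
y = [0, 1, 1, 1]
sensitive = [0, 0, 1, 1]
clf = GroupThresholdClassifier(thresholds={0: 0.5, 1: 0.3})

# Forward sensitive_features to predict() only.
acc_scorer = RoutingScorer(accuracy).set_response_request(sensitive_features=True)
# Forward the same metadata to both predict() and the metric.
dp_scorer = (
    RoutingScorer(selection_rate_difference)
    .set_response_request(sensitive_features=True)
    .set_score_request(sensitive_features=True)
)

acc = acc_scorer(clf, X, y, sensitive_features=sensitive)
dp = dp_scorer(clf, X, y, sensitive_features=sensitive)
print(acc, dp)
```

In this sketch `acc_scorer` forwards `sensitive_features` only to `predict`, while `dp_scorer` forwards it to both sides; metadata that neither side requested raises an error instead of being silently dropped.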
Describe alternatives you've considered, if relevant
No response
Additional context
https://fairlearn.org/v0.9/api_reference/generated/fairlearn.postprocessing.ThresholdOptimizer.html
https://fairlearn.org/v0.9/api_reference/generated/fairlearn.metrics.demographic_parity_difference.html