Meta-estimator will ignore sample_weight when a Pipeline is passed #21134

Closed
glemaitre opened this issue Sep 24, 2021 · 1 comment · Fixed by #26789

@glemaitre
Member

Describe the bug

Related to #18159

While working on #20610, I discovered that we have a silent bug in meta-estimators that check the signature of fit to know whether they should pass sample_weight. Indeed, Pipeline instead takes fit_params, through which weights can be passed to a specific estimator.

The simple signature check will thus fail. Some meta-estimators will then raise an error, as in BaggingClassifier; in CalibratedClassifierCV, however, the weights will be silently ignored, which is even worse.
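
Concretely, the check boils down to has_fit_parameter, which inspects the signature of fit (see the traceback below). Since Pipeline.fit only exposes **fit_params, the check returns False even when the inner estimator does support weights. A minimal illustration:

from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.validation import has_fit_parameter

# The estimator itself advertises sample_weight in its fit signature...
has_fit_parameter(DecisionTreeClassifier(), "sample_weight")  # True

# ...but wrapping it in a Pipeline hides it behind **fit_params.
has_fit_parameter(make_pipeline(DecisionTreeClassifier()), "sample_weight")  # False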

Steps/Code to Reproduce

# %%
import numpy as np
from sklearn.datasets import make_classification

class_weights = np.array([0.1, 0.9])
X, y = make_classification(n_samples=1_000, weights=class_weights, random_state=42)

# %%
from collections import Counter

Counter(y)  # imbalanced: ~10% class 0, ~90% class 1

# %%
# assign each sample the weight of its class
sample_weights = class_weights[(y == 1).astype(int)]
sample_weights

# %%
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test, sample_weights_train, sample_weights_test = \
    train_test_split(X, y, sample_weights, random_state=42)

# %%
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import balanced_accuracy_score

model = BaggingClassifier(base_estimator=DecisionTreeClassifier(), random_state=42)
model.fit(X_train, y_train, sample_weight=sample_weights_train)
y_pred = model.predict(X_test)
Counter(y_pred)

# %%
balanced_accuracy_score(y_test, y_pred)

# %%
from sklearn.pipeline import make_pipeline

# Simulate that we wrap the predictor in a Pipeline
model = BaggingClassifier(make_pipeline(DecisionTreeClassifier()), random_state=42)
model.fit(X_train, y_train, sample_weight=sample_weights_train)
y_pred = model.predict(X_test)
Counter(y_pred)

# %%
balanced_accuracy_score(y_test, y_pred)

Expected Results

Both examples should work identically.

Actual Results

ValueError: The base estimator doesn't support sample weight
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~/Documents/scratch/issue_sample_weights.py in <module>
      39 # Simulate that we wrap the predictor in a Pipeline
      40 model = BaggingClassifier(make_pipeline(DecisionTreeClassifier()), random_state=42)
----> 41 model.fit(X_train, y_train, sample_weight=sample_weights_train)
      42 y_pred = model.predict(X_test)
      43 Counter(y_pred)

~/Documents/packages/scikit-learn/sklearn/ensemble/_bagging.py in fit(self, X, y, sample_weight)
    258             Fitted estimator.
    259         """
--> 260         return self._fit(X, y, self.max_samples, sample_weight=sample_weight)
    261 
    262     def _parallel_args(self):

~/Documents/packages/scikit-learn/sklearn/ensemble/_bagging.py in _fit(self, X, y, max_samples, max_depth, sample_weight)
    392         self._seeds = seeds
    393 
--> 394         all_results = Parallel(
    395             n_jobs=n_jobs, verbose=self.verbose, **self._parallel_args()
    396         )(

~/Documents/packages/joblib/joblib/parallel.py in __call__(self, iterable)
   1039             # remaining jobs.
   1040             self._iterating = False
-> 1041             if self.dispatch_one_batch(iterator):
   1042                 self._iterating = self._original_iterator is not None
   1043 

~/Documents/packages/joblib/joblib/parallel.py in dispatch_one_batch(self, iterator)
    857                 return False
    858             else:
--> 859                 self._dispatch(tasks)
    860                 return True
    861 

~/Documents/packages/joblib/joblib/parallel.py in _dispatch(self, batch)
    775         with self._lock:
    776             job_idx = len(self._jobs)
--> 777             job = self._backend.apply_async(batch, callback=cb)
    778             # A job can complete so quickly than its callback is
    779             # called before we get here, causing self._jobs to

~/Documents/packages/joblib/joblib/_parallel_backends.py in apply_async(self, func, callback)
    206     def apply_async(self, func, callback=None):
    207         """Schedule a func to be run"""
--> 208         result = ImmediateResult(func)
    209         if callback:
    210             callback(result)

~/Documents/packages/joblib/joblib/_parallel_backends.py in __init__(self, batch)
    570         # Don't delay the application, to avoid keeping the input
    571         # arguments in memory
--> 572         self.results = batch()
    573 
    574     def get(self):

~/Documents/packages/joblib/joblib/parallel.py in __call__(self)
    260         # change the default number of processes to -1
    261         with parallel_backend(self._backend, n_jobs=self._n_jobs):
--> 262             return [func(*args, **kwargs)
    263                     for func, args, kwargs in self.items]
    264 

~/Documents/packages/joblib/joblib/parallel.py in <listcomp>(.0)
    260         # change the default number of processes to -1
    261         with parallel_backend(self._backend, n_jobs=self._n_jobs):
--> 262             return [func(*args, **kwargs)
    263                     for func, args, kwargs in self.items]
    264 

~/Documents/packages/scikit-learn/sklearn/utils/fixes.py in __call__(self, *args, **kwargs)
    207     def __call__(self, *args, **kwargs):
    208         with config_context(**self.config):
--> 209             return self.function(*args, **kwargs)
    210 
    211 

~/Documents/packages/scikit-learn/sklearn/ensemble/_bagging.py in _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight, seeds, total_n_estimators, verbose)
     80     support_sample_weight = has_fit_parameter(ensemble.base_estimator_, "sample_weight")
     81     if not support_sample_weight and sample_weight is not None:
---> 82         raise ValueError("The base estimator doesn't support sample weight")
     83 
     84     # Build estimators

ValueError: The base estimator doesn't support sample weight

Versions

System:
    python: 3.8.12 | packaged by conda-forge | (default, Sep 16 2021, 01:38:21)  [Clang 11.1.0 ]
executable: /Users/glemaitre/mambaforge/envs/dev/bin/python
   machine: macOS-11.6-arm64-arm-64bit

Python dependencies:
          pip: 21.2.4
   setuptools: 58.0.4
      sklearn: 1.1.dev0
        numpy: 1.21.2
        scipy: 1.7.1
       Cython: 0.29.24
       pandas: 1.3.3
   matplotlib: 3.4.3
       joblib: 1.0.1
threadpoolctl: 2.2.0

Built with OpenMP: True
@ogrisel
Member

ogrisel commented Sep 30, 2021

We should note that this problem can be fixed by the prototype implementation #20350 for SLEP006 on metadata routing.
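
For completeness, here is a sketch of what the fix looks like with the metadata routing API that eventually shipped. It assumes a recent scikit-learn (1.4+, where BaggingClassifier supports routing) and reuses X_train, y_train, and sample_weights_train from the reproducer above:

import sklearn
from sklearn.ensemble import BaggingClassifier
from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeClassifier

sklearn.set_config(enable_metadata_routing=True)

# The inner estimator explicitly requests sample_weight, so the pipeline
# and the bagging meta-estimator route the weights instead of inspecting
# fit signatures (sketch only; assumes scikit-learn >= 1.4).
tree = DecisionTreeClassifier().set_fit_request(sample_weight=True)
model = BaggingClassifier(make_pipeline(tree), random_state=42)
model.fit(X_train, y_train, sample_weight=sample_weights_train)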

BenjaminBossan added a commit to BenjaminBossan/scikit-learn that referenced this issue Aug 5, 2022
This PR adds metadata routing to CalibratedClassifierCV (CCV). CCV uses
a subestimator to create out-of-sample probabilities, which are in turn
used to calibrate the predicted probabilities.

The meta-estimator uses sample_weight. The subestimator may or may not
use sample_weight and additional metadata. So far, the check was whether
the subestimator has sample_weight in its fit signature: if so, the
weights were routed, otherwise not. This is not always correct, e.g. when
the subestimator is itself a pipeline (scikit-learn#21134). With routing,
this problem disappears.

In addition to these changes, the tests in
test_metaestimator_metadata_routing.py have been made more generic, as
they were previously specific to multioutput.
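
Assuming a release that includes this routing support in CalibratedClassifierCV, the pipeline case from this issue would then look roughly like this (a sketch, reusing the data from the reproducer above):

import sklearn
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

sklearn.set_config(enable_metadata_routing=True)

# The final step requests sample_weight, so CCV routes the weights through
# the pipeline instead of silently dropping them.
pipe = make_pipeline(
    StandardScaler(),
    LogisticRegression().set_fit_request(sample_weight=True),
)
calibrated = CalibratedClassifierCV(pipe)
calibrated.fit(X_train, y_train, sample_weight=sample_weights_train)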