RFC Should cross-validation splitters validate that all classes are represented in each split? #29558


Open
glemaitre opened this issue Jul 25, 2024 · 6 comments

glemaitre commented Jul 25, 2024

This is a follow-up to the issue raised in #29554. However, I recall other issues raised for CV estimators in general.

So the context is the following: a CV estimator uses an internal cross-validation scheme. When we deal with a classifier, the CV has no safety mechanism to make sure that the classifier was trained on all classes. Off the top of my head, this could happen for two reasons: (i) the target is sorted and the training folds do not contain all classes, and (ii) a class is underrepresented and not selected in some training folds.
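
For instance, a minimal sketch of reason (i): with a sorted target and the default shuffle=False, KFold yields a training fold that misses a class entirely:

import numpy as np
from sklearn.model_selection import KFold

X = np.zeros((10, 1))
y = np.array([0] * 8 + [1] * 2)  # sorted binary target

for train, _ in KFold(n_splits=5).split(X, y):
    # the last iteration prints [0.]: both samples of class 1 sit in
    # that fold's test set, so the training fold never sees class 1
    print(np.unique(y[train]))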

In all cases, even if fit does not fail, we still obtain a broken estimator. If it breaks at predict time, at least it does not silently return wrong predictions, but that is not guaranteed. Either way, we do not give the user direct feedback about what went wrong.

So I'm wondering if we should have some mechanism in the CV strategies to ensure that all classes have been observed at fit time. I don't think we should touch the estimators, because we would repeat a lot of code, and fundamentally the issue arises from the CV strategies.
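
A minimal sketch of what such a mechanism could look like, as a hypothetical wrapper around an existing splitter (ClassCoverageSplitter is an illustrative name, not a proposed API):

import numpy as np

class ClassCoverageSplitter:
    """Hypothetical wrapper: delegate to any splitter and raise if a
    training fold does not contain all classes observed in y."""

    def __init__(self, splitter):
        self.splitter = splitter

    def split(self, X, y, groups=None):
        y = np.asarray(y)
        classes = np.unique(y)
        for train, test in self.splitter.split(X, y, groups):
            missing = np.setdiff1d(classes, y[train])
            if missing.size:
                raise ValueError(
                    f"Training fold is missing classes {missing.tolist()}."
                )
            yield train, test

    def get_n_splits(self, X=None, y=None, groups=None):
        return self.splitter.get_n_splits(X, y, groups)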

NB: the same issue can happen when evaluating a plain classifier with cross-validation; it is not specific to CV estimators.
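
For example, a minimal sketch with cross_val_score: the fold whose training set misses a class only surfaces as a FitFailedWarning and a NaN score under the default error_score=np.nan:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score

X = np.random.RandomState(0).randn(10, 2)
y = np.array([0] * 8 + [1] * 2)  # sorted target

# the last fold's training set contains only class 0, so that fit
# fails; with error_score=np.nan this is reported as a
# FitFailedWarning plus a NaN score rather than a hard error
print(cross_val_score(LogisticRegression(), X, y, cv=KFold(n_splits=5)))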

@github-actions github-actions bot added the Needs Triage Issue requires triage label Jul 25, 2024
@glemaitre glemaitre added RFC and removed Needs Triage Issue requires triage labels Jul 25, 2024

adam2392 commented Jul 25, 2024

I think this is a possible issue in any classifier that inherently splits or resamples the data. E.g., in a random forest, the bootstrap can, in imbalanced cases, produce trees trained on a single class, which are useless. This makes me wonder why we don't stratify by default in the underlying CV/sample-splitting strategies.

E.g.

@pytest.mark.parametrize("Forest", FOREST_CLASSIFIERS.values())
def test_forest_sampling_imbalance(Forest):
    # Check that every tree in the forest was exposed to all classes,
    # even with an extremely imbalanced target.
    # (Assumes the module-level names of sklearn/ensemble/tests/test_forest.py:
    # np, pytest, datasets, FOREST_CLASSIFIERS, assert_array_equal.)
    X, _ = datasets.make_classification(
        n_samples=1000, n_features=20, n_informative=2, n_redundant=2, n_classes=2
    )
    y = np.zeros(1000)
    y[10] = 1  # a single sample of the minority class

    forest = Forest(n_estimators=10, random_state=0)
    forest.fit(X, y)

    # assert_array_equal raises on mismatch and returns None, so call it
    # in a loop rather than wrapping it in all()
    for estimator in forest.estimators_:
        assert_array_equal(np.unique(y), np.unique(estimator.classes_))
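
To illustrate the stratification point, a minimal sketch comparing KFold and StratifiedKFold on a sorted, imbalanced target:

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

X = np.zeros((100, 1))
y = np.array([0] * 95 + [1] * 5)

for name, cv in [("KFold", KFold(n_splits=5)), ("StratifiedKFold", StratifiedKFold(n_splits=5))]:
    ok = all(len(np.unique(y[train])) == 2 for train, _ in cv.split(X, y))
    # prints False for KFold (shuffle=False) and True for StratifiedKFold
    print(f"{name}: every training fold sees both classes: {ok}")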

@ogrisel ogrisel changed the title RFC Do cross-validation should validate training set for classifier RFC Should cross-validation splitters validate training set before classifier's own validation? Sep 30, 2024

ogrisel commented Sep 30, 2024

Do you have a particular code snippet to illustrate this problem, including the current traceback and error message one would get as a result?

@ogrisel ogrisel changed the title RFC Should cross-validation splitters validate training set before classifier's own validation? RFC Should cross-validation splitters validate that all classes are represented in each split? Sep 30, 2024

adam2392 commented Sep 30, 2024

This particular example is just for trees, where the randomness of the bootstrap usually helps alleviate this. I imagine this would occur in other areas of the library where bootstrapping is not done. I am not suggesting this is a major bug per se, but when imbalance occurs, there is a chance this degeneracy occurs.

Essentially, a tree that only contains one class is useless and is more or less discarded. In other estimators, if the cross-validation results in not all classes being present, there might be a silent error(?). I would have to investigate.

MWE

@pytest.mark.parametrize("Forest", FOREST_CLASSIFIERS.values())
def test_forest_sampling_imbalance(Forest):
    # Check that the forest can handle imbalanced classes with sampling
    # strategies.
    X, _ = datasets.make_classification(
        n_samples=100, n_features=20, n_informative=4, n_redundant=2, n_classes=2
    )
    y = np.zeros(100)
    y[10:15] = 1

    forest = Forest(
        n_estimators=100,
        random_state=0,
    )
    forest.fit(X, y)

Add the following assertion within _parallel_build_trees in sklearn/ensemble/_forest.py:

from numpy.testing import assert_array_equal

assert_array_equal(
    np.unique(y[curr_sample_weight > 0]), np.unique(y),
    err_msg=f"{np.unique(y[curr_sample_weight > 0])} != {np.unique(y)}",
)

Error trace:

>       forest.fit(X, y)

sklearn/ensemble/tests/test_forest.py:1881: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
sklearn/base.py:1241: in wrapper
    return fit_method(estimator, *args, **kwargs)
sklearn/ensemble/_forest.py:489: in fit
    trees = Parallel(
sklearn/utils/parallel.py:77: in __call__
    return super().__call__(iterable_with_config)
../../miniforge3/envs/sklearn-dev/lib/python3.10/site-packages/joblib/parallel.py:1918: in __call__
    return output if self.return_generator else list(output)
../../miniforge3/envs/sklearn-dev/lib/python3.10/site-packages/joblib/parallel.py:1847: in _get_sequential_output
    res = func(*args, **kwargs)
sklearn/utils/parallel.py:139: in __call__
    return self.function(*args, **kwargs)
sklearn/ensemble/_forest.py:183: in _parallel_build_trees
    assert_array_equal(np.unique(y[curr_sample_weight > 0]), np.unique(y)), f"{np.unique(y[curr_sample_weight > 0])} != {np.unique(y)}"
../../miniforge3/envs/sklearn-dev/lib/python3.10/site-packages/numpy/_utils/__init__.py:85: in wrapper
    return fun(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (<built-in function eq>, array([0.]), array([0., 1.]))
kwds = {'err_msg': '', 'header': 'Arrays are not equal', 'strict': False, 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Arrays are not equal
E           
E           (shapes (1,), (2,) mismatch)
E            ACTUAL: array([0.])
E            DESIRED: array([0., 1.])

../../miniforge3/envs/sklearn-dev/lib/python3.10/contextlib.py:79: AssertionError


ogrisel commented Sep 30, 2024

Wouldn't an easy fix be to change _parallel_build_trees to properly handle zero-weighted classes in the training set?


ogrisel commented Sep 30, 2024

Maybe we could even introduce a common test to check that all classifiers accept being fit with some zero-weighted classes.
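
A minimal sketch of what such a common check could exercise (hypothetical, not the actual machinery in sklearn.utils.estimator_checks):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=60, n_classes=3, n_informative=4, random_state=0)
sample_weight = np.ones(len(y))
sample_weight[y == 2] = 0.0  # one class carries zero weight entirely

# whether each estimator copes gracefully is exactly what the proposed
# common test would reveal, hence the try/except
for clf in [LogisticRegression(), DecisionTreeClassifier(random_state=0)]:
    try:
        clf.fit(X, y, sample_weight=sample_weight).predict(X)
        print(f"{type(clf).__name__}: handles a zero-weighted class")
    except Exception as exc:
        print(f"{type(clf).__name__}: fails with {exc!r}")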

glemaitre (Author) commented:

Another case with LogisticRegressionCV:

In [1]: import numpy as np
   ...: from sklearn.linear_model import LogisticRegressionCV
   ...: X = np.zeros((10, 1))
   ...: y = [1, 1, 1, 1, 1, 2, 2, 2, 2, 3]
   ...: logreg = LogisticRegressionCV(cv=5, scoring='neg_log_loss')
   ...: logreg.fit(X, y)
/Users/glemaitre/Documents/packages/scikit-learn/sklearn/model_selection/_split.py:779: UserWarning: The least populated class in y has only 1 members, which is less than n_splits=5.
  warnings.warn(
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 6
      4 y = [1, 1, 1, 1, 1, 2, 2, 2, 2, 3]
      5 logreg = LogisticRegressionCV(cv=5, scoring='neg_log_loss')
----> 6 logreg.fit(X, y)

File ~/Documents/packages/scikit-learn/sklearn/base.py:1244, in _fit_context.<locals>.decorator.<locals>.wrapper(estimator, *args, **kwargs)
   1237     estimator._validate_params()
   1239 with config_context(
   1240     skip_parameter_validation=(
   1241         prefer_skip_nested_validation or global_skip_validation
   1242     )
   1243 ):
-> 1244     return fit_method(estimator, *args, **kwargs)

File ~/Documents/packages/scikit-learn/sklearn/linear_model/_logistic.py:2000, in LogisticRegressionCV.fit(self, X, y, sample_weight, **params)
   1997 else:
   1998     prefer = "processes"
-> 2000 fold_coefs_ = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, prefer=prefer)(
   2001     path_func(
   2002         X,
   2003         y,
   2004         train,
   2005         test,
   2006         pos_class=label,
   2007         Cs=self.Cs,
   2008         fit_intercept=self.fit_intercept,
   2009         penalty=self.penalty,
   2010         dual=self.dual,
   2011         solver=solver,
   2012         tol=self.tol,
   2013         max_iter=self.max_iter,
   2014         verbose=self.verbose,
   2015         class_weight=class_weight,
   2016         scoring=self.scoring,
   2017         multi_class=multi_class,
   2018         intercept_scaling=self.intercept_scaling,
   2019         random_state=self.random_state,
   2020         max_squared_sum=max_squared_sum,
   2021         sample_weight=sample_weight,
   2022         l1_ratio=l1_ratio,
   2023         score_params=routed_params.scorer.score,
   2024     )
   2025     for label in iter_encoded_labels
   2026     for train, test in folds
   2027     for l1_ratio in l1_ratios_
   2028 )
   2030 # _log_reg_scoring_path will output different shapes depending on the
   2031 # multi_class param, so we need to reshape the outputs accordingly.
   2032 # Cs is of shape (n_classes . n_folds . n_l1_ratios, n_Cs) and all the
   (...)
   2039 #  (n_classes, n_folds, n_Cs . n_l1_ratios) or
   2040 #  (1, n_folds, n_Cs . n_l1_ratios)
   2041 coefs_paths, Cs, scores, n_iter_ = zip(*fold_coefs_)

File ~/Documents/packages/scikit-learn/sklearn/utils/parallel.py:77, in Parallel.__call__(self, iterable)
     72 config = get_config()
     73 iterable_with_config = (
     74     (_with_config(delayed_func, config), args, kwargs)
     75     for delayed_func, args, kwargs in iterable
     76 )
---> 77 return super().__call__(iterable_with_config)

File ~/mambaforge/envs/sklearn_dev/lib/python3.10/site-packages/joblib/parallel.py:1863, in Parallel.__call__(self, iterable)
   1861     output = self._get_sequential_output(iterable)
   1862     next(output)
-> 1863     return output if self.return_generator else list(output)
   1865 # Let's create an ID that uniquely identifies the current call. If the
   1866 # call is interrupted early and that the same instance is immediately
   1867 # re-used, this id will be used to prevent workers that were
   1868 # concurrently finalizing a task from the previous call to run the
   1869 # callback.
   1870 with self._lock:

File ~/mambaforge/envs/sklearn_dev/lib/python3.10/site-packages/joblib/parallel.py:1792, in Parallel._get_sequential_output(self, iterable)
   1790 self.n_dispatched_batches += 1
   1791 self.n_dispatched_tasks += 1
-> 1792 res = func(*args, **kwargs)
   1793 self.n_completed_tasks += 1
   1794 self.print_progress()

File ~/Documents/packages/scikit-learn/sklearn/utils/parallel.py:139, in _FuncWrapper.__call__(self, *args, **kwargs)
    137     config = {}
    138 with config_context(**config):
--> 139     return self.function(*args, **kwargs)

File ~/Documents/packages/scikit-learn/sklearn/linear_model/_logistic.py:803, in _log_reg_scoring_path(X, y, train, test, pos_class, Cs, scoring, fit_intercept, max_iter, tol, class_weight, verbose, solver, penalty, dual, intercept_scaling, multi_class, random_state, max_squared_sum, sample_weight, l1_ratio, score_params)
    801         score_params = score_params or {}
    802         score_params = _check_method_params(X=X, params=score_params, indices=test)
--> 803         scores.append(scoring(log_reg, X_test, y_test, **score_params))
    804 return coefs, Cs, np.array(scores), n_iter

File ~/Documents/packages/scikit-learn/sklearn/metrics/_scorer.py:288, in _BaseScorer.__call__(self, estimator, X, y_true, sample_weight, **kwargs)
    285 if sample_weight is not None:
    286     _kwargs["sample_weight"] = sample_weight
--> 288 return self._score(partial(_cached_call, None), estimator, X, y_true, **_kwargs)

File ~/Documents/packages/scikit-learn/sklearn/metrics/_scorer.py:388, in _Scorer._score(self, method_caller, estimator, X, y_true, **kwargs)
    380 y_pred = method_caller(
    381     estimator,
    382     _get_response_method_name(response_method),
    383     X,
    384     pos_label=pos_label,
    385 )
    387 scoring_kwargs = {**self._kwargs, **kwargs}
--> 388 return self._sign * self._score_func(y_true, y_pred, **scoring_kwargs)

File ~/Documents/packages/scikit-learn/sklearn/utils/_param_validation.py:189, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    187 global_skip_validation = get_config()["skip_parameter_validation"]
    188 if global_skip_validation:
--> 189     return func(*args, **kwargs)
    191 func_sig = signature(func)
    193 # Map *args/**kwargs to the function signature

File ~/Documents/packages/scikit-learn/sklearn/metrics/_classification.py:3033, in log_loss(y_true, y_pred, normalize, sample_weight, labels)
   3031 if len(lb.classes_) != y_pred.shape[1]:
   3032     if labels is None:
-> 3033         raise ValueError(
   3034             "y_true and y_pred contain different number of "
   3035             "classes {0}, {1}. Please provide the true "
   3036             "labels explicitly through the labels argument. "
   3037             "Classes found in "
   3038             "y_true: {2}".format(
   3039                 transformed_labels.shape[1], y_pred.shape[1], lb.classes_
   3040             )
   3041         )
   3042     else:
   3043         raise ValueError(
   3044             "The number of classes in labels is different "
   3045             "from that in y_pred. Classes found in "
   3046             "labels: {0}".format(lb.classes_)
   3047         )

ValueError: y_true and y_pred contain different number of classes 2, 3. Please provide the true labels explicitly through the labels argument. Classes found in y_true: [0 1]
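
For context, a minimal sketch of the split behavior behind this failure: with a single sample of class 3 and cv=5, StratifiedKFold only warns and still yields splits, so class 3 is missing from one side of every fold:

import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.zeros((10, 1))
y = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2, 3])

# emits the UserWarning seen above, then yields the splits anyway
for train, test in StratifiedKFold(n_splits=5).split(X, y):
    print("train:", np.unique(y[train]), "test:", np.unique(y[test]))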
