RFC Should cross-validation splitters validate that all classes are represented in each split? #29558
I think this is a possible issue in any classifier that inherently splits or resamples the data. E.g. in a random forest, the bootstrap can, in imbalanced cases, produce trees that effectively see a single class, which makes those trees useless. This makes me wonder why we don't stratify by default in the underlying CV/sample-splitting strategies. E.g.:

# Meant to live in sklearn/ensemble/tests/test_forest.py, where np, pytest,
# datasets, FOREST_CLASSIFIERS and assert_array_equal are already available.
@pytest.mark.parametrize("Forest", FOREST_CLASSIFIERS.values())
def test_forest_sampling_imbalance(Forest):
    # Check that the forest can handle imbalanced classes with sampling
    # strategies.
    X, _ = datasets.make_classification(
        n_samples=1000, n_features=20, n_informative=2, n_redundant=2, n_classes=2
    )
    y = np.zeros(1000)
    y[10] = 1
    forest = Forest(
        n_estimators=10,
        random_state=0,
    )
    forest.fit(X, y)
    # assert_array_equal raises on mismatch, so loop over the sub-estimators
    # instead of wrapping it in all()
    for estimator in forest.estimators_:
        assert_array_equal(np.unique(y), np.unique(estimator.classes_))
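To make the degeneracy concrete, here is a small standalone sketch (the setup is illustrative and not part of the proposed test) showing how often a plain bootstrap of 1000 samples misses the single minority sample:

import numpy as np

# One minority sample among 1000. A bootstrap resample of the same size
# misses that sample with probability (999/1000)**1000 ~ 0.37, so trees fit
# on such resamples see a single class.
rng = np.random.RandomState(0)
y = np.zeros(1000)
y[10] = 1

n_single_class = 0
for _ in range(10):
    indices = rng.randint(0, 1000, size=1000)  # bootstrap: draw with replacement
    if np.unique(y[indices]).size < 2:
        n_single_class += 1
print(f"{n_single_class} of 10 bootstrap resamples contain a single class")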
Do you have a particular code snippet to illustrate this problem, including the current traceback and error message one would get as a result?
This particular example is just for trees, where the randomness of the bootstrap usually helps alleviate this. I imagine this would occur in other areas of the library where no bootstrap is done. I am not suggesting this is a major bug per se, but when imbalance occurs, there is a chance this degeneracy happens. Essentially, a tree that contains only one class is useless and is more or less discarded. In other estimators, if the cross-validation results in not all classes being present, there might be a silent error(?). Would have to investigate.

MWE
Add the following code within
Error trace:
Wouldn't an easy way be to change
Maybe we could even introduce a common test to check that all classifiers accept being fit with some zero-weighted classes.
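Such a common test might look roughly like the sketch below; the helper name and the choice of LogisticRegression are purely illustrative, this is not an existing scikit-learn check:

import numpy as np
from sklearn.linear_model import LogisticRegression

# Hypothetical sketch of such a common test: fit with one class entirely
# zero-weighted and make sure fit/predict do not crash.
def check_fit_with_zero_weighted_class(estimator):
    rng = np.random.RandomState(0)
    X = rng.randn(30, 3)
    y = np.array([0] * 10 + [1] * 10 + [2] * 10)
    sample_weight = np.where(y == 2, 0.0, 1.0)  # class 2 carries zero weight
    estimator.fit(X, y, sample_weight=sample_weight)
    estimator.predict(X)

try:
    check_fit_with_zero_weighted_class(LogisticRegression())
    print("LogisticRegression accepts a zero-weighted class")
except Exception as exc:
    # estimators that cannot handle a zero-weighted class would be flagged here
    print(f"LogisticRegression failed with a zero-weighted class: {exc!r}")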
Another case with LogisticRegressionCV:

In [1]: import numpy as np
   ...: from sklearn.linear_model import LogisticRegressionCV
   ...: X = np.zeros((10, 1))
   ...: y = [1, 1, 1, 1, 1, 2, 2, 2, 2, 3]
   ...: logreg = LogisticRegressionCV(cv=5, scoring='neg_log_loss')
   ...: logreg.fit(X, y)
/Users/glemaitre/Documents/packages/scikit-learn/sklearn/model_selection/_split.py:779: UserWarning: The least populated class in y has only 1 members, which is less than n_splits=5.
  warnings.warn(
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[1], line 6
4 y = [1, 1, 1, 1, 1, 2, 2, 2, 2, 3]
5 logreg = LogisticRegressionCV(cv=5, scoring='neg_log_loss')
----> 6 logreg.fit(X, y)
File ~/Documents/packages/scikit-learn/sklearn/base.py:1244, in _fit_context.<locals>.decorator.<locals>.wrapper(estimator, *args, **kwargs)
1237 estimator._validate_params()
1239 with config_context(
1240 skip_parameter_validation=(
1241 prefer_skip_nested_validation or global_skip_validation
1242 )
1243 ):
-> 1244 return fit_method(estimator, *args, **kwargs)
File ~/Documents/packages/scikit-learn/sklearn/linear_model/_logistic.py:2000, in LogisticRegressionCV.fit(self, X, y, sample_weight, **params)
1997 else:
1998 prefer = "processes"
-> 2000 fold_coefs_ = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, prefer=prefer)(
2001 path_func(
2002 X,
2003 y,
2004 train,
2005 test,
2006 pos_class=label,
2007 Cs=self.Cs,
2008 fit_intercept=self.fit_intercept,
2009 penalty=self.penalty,
2010 dual=self.dual,
2011 solver=solver,
2012 tol=self.tol,
2013 max_iter=self.max_iter,
2014 verbose=self.verbose,
2015 class_weight=class_weight,
2016 scoring=self.scoring,
2017 multi_class=multi_class,
2018 intercept_scaling=self.intercept_scaling,
2019 random_state=self.random_state,
2020 max_squared_sum=max_squared_sum,
2021 sample_weight=sample_weight,
2022 l1_ratio=l1_ratio,
2023 score_params=routed_params.scorer.score,
2024 )
2025 for label in iter_encoded_labels
2026 for train, test in folds
2027 for l1_ratio in l1_ratios_
2028 )
2030 # _log_reg_scoring_path will output different shapes depending on the
2031 # multi_class param, so we need to reshape the outputs accordingly.
2032 # Cs is of shape (n_classes . n_folds . n_l1_ratios, n_Cs) and all the
(...)
2039 # (n_classes, n_folds, n_Cs . n_l1_ratios) or
2040 # (1, n_folds, n_Cs . n_l1_ratios)
2041 coefs_paths, Cs, scores, n_iter_ = zip(*fold_coefs_)
File ~/Documents/packages/scikit-learn/sklearn/utils/parallel.py:77, in Parallel.__call__(self, iterable)
72 config = get_config()
73 iterable_with_config = (
74 (_with_config(delayed_func, config), args, kwargs)
75 for delayed_func, args, kwargs in iterable
76 )
---> 77 return super().__call__(iterable_with_config)
File ~/mambaforge/envs/sklearn_dev/lib/python3.10/site-packages/joblib/parallel.py:1863, in Parallel.__call__(self, iterable)
1861 output = self._get_sequential_output(iterable)
1862 next(output)
-> 1863 return output if self.return_generator else list(output)
1865 # Let's create an ID that uniquely identifies the current call. If the
1866 # call is interrupted early and that the same instance is immediately
1867 # re-used, this id will be used to prevent workers that were
1868 # concurrently finalizing a task from the previous call to run the
1869 # callback.
1870 with self._lock:
File ~/mambaforge/envs/sklearn_dev/lib/python3.10/site-packages/joblib/parallel.py:1792, in Parallel._get_sequential_output(self, iterable)
1790 self.n_dispatched_batches += 1
1791 self.n_dispatched_tasks += 1
-> 1792 res = func(*args, **kwargs)
1793 self.n_completed_tasks += 1
1794 self.print_progress()
File ~/Documents/packages/scikit-learn/sklearn/utils/parallel.py:139, in _FuncWrapper.__call__(self, *args, **kwargs)
137 config = {}
138 with config_context(**config):
--> 139 return self.function(*args, **kwargs)
File ~/Documents/packages/scikit-learn/sklearn/linear_model/_logistic.py:803, in _log_reg_scoring_path(X, y, train, test, pos_class, Cs, scoring, fit_intercept, max_iter, tol, class_weight, verbose, solver, penalty, dual, intercept_scaling, multi_class, random_state, max_squared_sum, sample_weight, l1_ratio, score_params)
801 score_params = score_params or {}
802 score_params = _check_method_params(X=X, params=score_params, indices=test)
--> 803 scores.append(scoring(log_reg, X_test, y_test, **score_params))
804 return coefs, Cs, np.array(scores), n_iter
File ~/Documents/packages/scikit-learn/sklearn/metrics/_scorer.py:288, in _BaseScorer.__call__(self, estimator, X, y_true, sample_weight, **kwargs)
285 if sample_weight is not None:
286 _kwargs["sample_weight"] = sample_weight
--> 288 return self._score(partial(_cached_call, None), estimator, X, y_true, **_kwargs)
File ~/Documents/packages/scikit-learn/sklearn/metrics/_scorer.py:388, in _Scorer._score(self, method_caller, estimator, X, y_true, **kwargs)
380 y_pred = method_caller(
381 estimator,
382 _get_response_method_name(response_method),
383 X,
384 pos_label=pos_label,
385 )
387 scoring_kwargs = {**self._kwargs, **kwargs}
--> 388 return self._sign * self._score_func(y_true, y_pred, **scoring_kwargs)
File ~/Documents/packages/scikit-learn/sklearn/utils/_param_validation.py:189, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
187 global_skip_validation = get_config()["skip_parameter_validation"]
188 if global_skip_validation:
--> 189 return func(*args, **kwargs)
191 func_sig = signature(func)
193 # Map *args/**kwargs to the function signature
File ~/Documents/packages/scikit-learn/sklearn/metrics/_classification.py:3033, in log_loss(y_true, y_pred, normalize, sample_weight, labels)
3031 if len(lb.classes_) != y_pred.shape[1]:
3032 if labels is None:
-> 3033 raise ValueError(
3034 "y_true and y_pred contain different number of "
3035 "classes {0}, {1}. Please provide the true "
3036 "labels explicitly through the labels argument. "
3037 "Classes found in "
3038 "y_true: {2}".format(
3039 transformed_labels.shape[1], y_pred.shape[1], lb.classes_
3040 )
3041 )
3042 else:
3043 raise ValueError(
3044 "The number of classes in labels is different "
3045 "from that in y_pred. Classes found in "
3046 "labels: {0}".format(lb.classes_)
3047 )
ValueError: y_true and y_pred contain different number of classes 2, 3. Please provide the true labels explicitly through the labels argument. Classes found in y_true: [0 1]
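For context on where this failure comes from, reproducing only the splitting step already shows the problem. The sketch below assumes that LogisticRegressionCV with an integer cv relies on StratifiedKFold internally for classification; with a class that has a single sample, some train or test folds inevitably miss a class:

import numpy as np
from sklearn.model_selection import StratifiedKFold

# Same data as above: class 3 has a single sample, so StratifiedKFold warns
# and produces folds where either the train or the test side misses a class.
X = np.zeros((10, 1))
y = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2, 3])
for i, (train, test) in enumerate(StratifiedKFold(n_splits=5).split(X, y)):
    print(i, "train classes:", np.unique(y[train]), "test classes:", np.unique(y[test]))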
This is a follow-up to the issue raised in #29554. However, I recall other issues raised for CV estimators in general.

So the context is the following: a CV estimator will use an internal cross-validation scheme. When we deal with a classifier, we don't have any safety mechanism in the CV to make sure that the classifier was trained on all classes. Off the top of my head, this could happen for two reasons: (i) the target is sorted and the training folds do not contain all classes, or (ii) a class is underrepresented and potentially not selected.

In all cases, if the fit does not fail, we still obtain a broken estimator. If it breaks at predict, at least it is not silently giving wrong predictions, but even that is not a given. Either way, we don't provide direct feedback to the user about what went wrong.

So I'm wondering if we should have some sort of mechanism in the CV strategies to ensure that all classes have been observed at fit time. I don't think that we should touch the estimators, because we would repeat a lot of code and, fundamentally, the issue is caused by the CV strategies.

NB: the same issue can happen with a plain classifier evaluated through cross-validation; it is not limited to CV estimators.
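To make the proposal concrete, here is a rough sketch of the kind of safety mechanism this RFC asks about. The wrapper below is hypothetical (not an existing scikit-learn API) and only assumes the usual splitter interface (split / get_n_splits):

import numpy as np

# Hypothetical wrapper that rejects any split whose training fold does not
# contain every class observed in y.
class AllClassesRequiredCV:
    def __init__(self, cv):
        self.cv = cv

    def split(self, X, y, groups=None):
        y = np.asarray(y)
        classes = np.unique(y)
        for train, test in self.cv.split(X, y, groups):
            missing = np.setdiff1d(classes, np.unique(y[train]))
            if missing.size:
                raise ValueError(
                    f"Training fold is missing class(es) {missing.tolist()}; "
                    "an estimator fit on this fold could never predict them."
                )
            yield train, test

    def get_n_splits(self, X=None, y=None, groups=None):
        return self.cv.get_n_splits(X, y, groups)

Passing such a wrapper as the cv of a CV estimator (or of cross_validate) would turn the silent degeneracy into an explicit error; the open question is whether the built-in splitters should perform this check, or at least warn, themselves.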