diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index d8776653cd9e8..92983b0041b74 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -2,6 +2,23 @@
 
 .. currentmodule:: sklearn
 
+.. _changes_1_0_1:
+
+Version 1.0.1
+=============
+
+**In Development**
+
+Changelog
+---------
+
+:mod:`sklearn.calibration`
+..........................
+
+- |Fix| Fixed :class:`calibration.CalibratedClassifierCV` to handle
+  `sample_weight` correctly when `ensemble=False`.
+  :pr:`20638` by :user:`Julien Bohné`.
+
 .. _changes_1_0:
 
 Version 1.0.0
diff --git a/sklearn/calibration.py b/sklearn/calibration.py
index 9ede41a775c3e..700f1ed6cc397 100644
--- a/sklearn/calibration.py
+++ b/sklearn/calibration.py
@@ -351,6 +351,11 @@ def fit(self, X, y, sample_weight=None):
         else:
             this_estimator = clone(base_estimator)
             _, method_name = _get_prediction_method(this_estimator)
+            fit_params = (
+                {"sample_weight": sample_weight}
+                if sample_weight is not None and supports_sw
+                else None
+            )
             pred_method = partial(
                 cross_val_predict,
                 estimator=this_estimator,
@@ -359,6 +364,7 @@ def fit(self, X, y, sample_weight=None):
                 cv=cv,
                 method=method_name,
                 n_jobs=self.n_jobs,
+                fit_params=fit_params,
             )
             predictions = _compute_predictions(
                 pred_method, method_name, X, n_classes
diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py
index 040571df4681b..da1645a1c0fd6 100644
--- a/sklearn/tests/test_calibration.py
+++ b/sklearn/tests/test_calibration.py
@@ -166,6 +166,12 @@ def test_sample_weight(data, method, ensemble):
     X_train, y_train, sw_train = X[:n_samples], y[:n_samples], sample_weight[:n_samples]
     X_test = X[n_samples:]
 
+    scaler = StandardScaler()
+    X_train = scaler.fit_transform(
+        X_train
+    )  # compute mean and std on the training data, then transform it
+    X_test = scaler.transform(X_test)
+
     base_estimator = LinearSVC(random_state=42)
     calibrated_clf = CalibratedClassifierCV(
         base_estimator, method=method, ensemble=ensemble
@@ -182,6 +188,68 @@ def test_sample_weight(data, method, ensemble):
     assert diff > 0.1
 
 
+@pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
+@pytest.mark.parametrize("ensemble", [True, False])
+def test_sample_weight_class_imbalanced(method, ensemble):
+    """Use an imbalanced dataset to check that `sample_weight` is taken into
+    account in the calibration estimator."""
+    X, y = make_blobs((100, 1000), center_box=(-1, 1), random_state=42)
+
+    # Compute weights to compensate for the imbalance of the dataset
+    weights = np.array([0.9, 0.1])
+    sample_weight = weights[(y == 1).astype(int)]
+
+    X_train, X_test, y_train, y_test, sw_train, sw_test = train_test_split(
+        X, y, sample_weight, stratify=y, random_state=42
+    )
+
+    # FIXME: ideally we should create a `Pipeline` with the `StandardScaler`
+    # followed by the `LinearSVC`. However, `Pipeline` does not expose
+    # `sample_weight` and it will be silently ignored.
+    scaler = StandardScaler()
+    X_train = scaler.fit_transform(X_train)
+    X_test = scaler.transform(X_test)
+
+    base_estimator = LinearSVC(random_state=42)
+    calibrated_clf = CalibratedClassifierCV(
+        base_estimator, method=method, ensemble=ensemble
+    )
+    calibrated_clf.fit(X_train, y_train, sample_weight=sw_train)
+    predictions = calibrated_clf.predict_proba(X_test)[:, 1]
+
+    assert brier_score_loss(y_test, predictions, sample_weight=sw_test) < 0.2
+
+
+@pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
+def test_sample_weight_class_imbalanced_ensemble_equivalent(method):
+    X, y = make_blobs((100, 1000), center_box=(-1, 1), random_state=42)
+
+    # Compute weights to compensate for the imbalance of the dataset
+    sample_weight = 9 * (y == 0) + 1
+
+    X_train, X_test, y_train, y_test, sw_train, sw_test = train_test_split(
+        X, y, sample_weight, stratify=y, random_state=42
+    )
+
+    scaler = StandardScaler()
+    X_train = scaler.fit_transform(
+        X_train
+    )  # compute mean and std on the training data, then transform it
+    X_test = scaler.transform(X_test)
+
+    predictions = []
+    for ensemble in [True, False]:
+        base_estimator = LinearSVC(random_state=42)
+        calibrated_clf = CalibratedClassifierCV(
+            base_estimator, method=method, ensemble=ensemble
+        )
+        calibrated_clf.fit(X_train, y_train, sample_weight=sw_train)
+        predictions.append(calibrated_clf.predict_proba(X_test)[:, 1])
+
+    diff = np.linalg.norm(predictions[0] - predictions[1])
+    assert diff < 1.5
+
+
 @pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
 @pytest.mark.parametrize("ensemble", [True, False])
 def test_parallel_execution(data, method, ensemble):
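
Reviewer note: the script below is not part of the patch; it is a minimal standalone sketch assembled from the new tests above. It illustrates the behavior the fix enables: now that `fit_params` is forwarded to `cross_val_predict`, the `ensemble=False` path also fits its cross-validated base estimators with `sample_weight`, so on a reweighted imbalanced problem the two `ensemble` settings should produce similar calibrated probabilities.

import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

# Imbalanced two-class problem: 100 samples of class 0, 1000 of class 1.
X, y = make_blobs((100, 1000), center_box=(-1, 1), random_state=42)
# Up-weight the minority class (9x) to compensate for the imbalance.
sample_weight = 9 * (y == 0) + 1

X_train, X_test, y_train, y_test, sw_train, sw_test = train_test_split(
    X, y, sample_weight, stratify=y, random_state=42
)
# Scale outside the estimator: a Pipeline would silently ignore sample_weight.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

predictions = {}
for ensemble in [True, False]:
    clf = CalibratedClassifierCV(
        LinearSVC(random_state=42), method="sigmoid", ensemble=ensemble
    )
    # With this patch, sample_weight also reaches the cross-validated
    # base estimators when ensemble=False.
    clf.fit(X_train, y_train, sample_weight=sw_train)
    predictions[ensemble] = clf.predict_proba(X_test)[:, 1]

# Before the fix, the ensemble=False predictions ignored sample_weight and
# this distance was large; after the fix the two settings roughly agree.
print(np.linalg.norm(predictions[True] - predictions[False]))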