diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst
index e4e965aca02d8..3ef8a7653b5f7 100644
--- a/doc/whats_new/v1.3.rst
+++ b/doc/whats_new/v1.3.rst
@@ -114,6 +114,12 @@ Changelog
 :mod:`sklearn.ensemble`
 .......................
 
+- |Feature| :class:`ensemble.HistGradientBoostingRegressor` now supports
+  the Gamma deviance loss via `loss="gamma"`.
+  Using the Gamma deviance as loss function is useful for modelling skewed,
+  strictly positive targets.
+  :pr:`22409` by :user:`Christian Lorentzen <lorentzenchr>`.
+
 - |Feature| Compute a custom out-of-bag score by passing a callable to
   :class:`ensemble.RandomForestClassifier`, :class:`ensemble.RandomForestRegressor`,
   :class:`ensemble.ExtraTreesClassifier` and :class:`ensemble.ExtraTreesRegressor`.
diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
index 74322eef79bbc..31069fe14ee41 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -13,6 +13,7 @@
     _LOSSES,
     BaseLoss,
     HalfBinomialLoss,
+    HalfGammaLoss,
     HalfMultinomialLoss,
     HalfPoissonLoss,
     PinballLoss,
@@ -43,6 +44,7 @@
 _LOSSES.update(
     {
         "poisson": HalfPoissonLoss,
+        "gamma": HalfGammaLoss,
         "quantile": PinballLoss,
         "binary_crossentropy": HalfBinomialLoss,
         "categorical_crossentropy": HalfMultinomialLoss,
@@ -1204,13 +1206,14 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
 
     Parameters
     ----------
-    loss : {'squared_error', 'absolute_error', 'poisson', 'quantile'}, \
+    loss : {'squared_error', 'absolute_error', 'gamma', 'poisson', 'quantile'}, \
             default='squared_error'
         The loss function to use in the boosting process. Note that the
-        "squared error" and "poisson" losses actually implement
-        "half least squares loss" and "half poisson deviance" to simplify the
-        computation of the gradient. Furthermore, "poisson" loss internally
-        uses a log-link and requires ``y >= 0``.
+        "squared error", "gamma" and "poisson" losses actually implement
+        "half least squares loss", "half gamma deviance" and "half poisson
+        deviance" to simplify the computation of the gradient. Furthermore,
+        "gamma" and "poisson" losses internally use a log-link, "gamma"
+        requires ``y > 0`` and "poisson" requires ``y >= 0``.
         "quantile" uses the pinball loss.
 
         .. versionchanged:: 0.23
@@ -1219,6 +1222,9 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
         .. versionchanged:: 1.1
            Added option 'quantile'.
 
+        .. versionchanged:: 1.3
+           Added option 'gamma'.
+
     quantile : float, default=None
         If loss is "quantile", this parameter specifies which quantile to be estimated
         and must be between 0 and 1.
@@ -1418,7 +1424,15 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
     _parameter_constraints: dict = {
         **BaseHistGradientBoosting._parameter_constraints,
         "loss": [
-            StrOptions({"squared_error", "absolute_error", "poisson", "quantile"}),
+            StrOptions(
+                {
+                    "squared_error",
+                    "absolute_error",
+                    "poisson",
+                    "gamma",
+                    "quantile",
+                }
+            ),
             BaseLoss,
         ],
         "quantile": [Interval(Real, 0, 1, closed="both"), None],
@@ -1514,7 +1528,11 @@ def _encode_y(self, y):
         # Just convert y to the expected dtype
         self.n_trees_per_iteration_ = 1
         y = y.astype(Y_DTYPE, copy=False)
-        if self.loss == "poisson":
+        if self.loss == "gamma":
+            # Ensure y > 0
+            if not np.all(y > 0):
+                raise ValueError("loss='gamma' requires strictly positive y.")
+        elif self.loss == "poisson":
             # Ensure y >= 0 and sum(y) > 0
             if not (np.all(y >= 0) and np.sum(y) > 0):
                 raise ValueError(
diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py
index f5c373ed84558..a697d385140d5 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py
@@ -11,6 +11,17 @@
 
 
 @pytest.mark.parametrize("seed", range(5))
+@pytest.mark.parametrize(
+    "loss",
+    [
+        "squared_error",
+        "poisson",
+        pytest.param(
+            "gamma",
+            marks=pytest.mark.skip("LightGBM with gamma loss has larger deviation."),
+        ),
+    ],
+)
 @pytest.mark.parametrize("min_samples_leaf", (1, 20))
 @pytest.mark.parametrize(
     "n_samples, max_leaf_nodes",
@@ -19,7 +30,9 @@
         (1000, 8),
     ],
 )
-def test_same_predictions_regression(seed, min_samples_leaf, n_samples, max_leaf_nodes):
+def test_same_predictions_regression(
+    seed, loss, min_samples_leaf, n_samples, max_leaf_nodes
+):
     # Make sure sklearn has the same predictions as lightgbm for easy targets.
     #
     # In particular when the size of the trees are bound and the number of
@@ -33,7 +46,7 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples, max_leaf
     # is not exactly the same. To avoid this issue we only compare the
     # predictions on the test set when the number of samples is large enough
     # and max_leaf_nodes is low enough.
-    # - To ignore discrepancies caused by small differences the binning
+    # - To ignore discrepancies caused by small differences in the binning
     #   strategy, data is pre-binned if n_samples > 255.
     # - We don't check the absolute_error loss here. This is because
     #   LightGBM's computation of the median (used for the initial value of
@@ -52,6 +65,10 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples, max_leaf
         n_samples=n_samples, n_features=5, n_informative=5, random_state=0
     )
 
+    if loss in ("gamma", "poisson"):
+        # make the target positive
+        y = np.abs(y) + np.mean(np.abs(y))
+
     if n_samples > 255:
         # bin data and convert it to float32 so that the estimator doesn't
         # treat it as pre-binned
@@ -60,6 +77,7 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples, max_leaf
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)
 
     est_sklearn = HistGradientBoostingRegressor(
+        loss=loss,
         max_iter=max_iter,
         max_bins=max_bins,
         learning_rate=1,
@@ -68,6 +86,7 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples, max_leaf
         max_leaf_nodes=max_leaf_nodes,
     )
     est_lightgbm = get_equivalent_estimator(est_sklearn, lib="lightgbm")
+    est_lightgbm.set_params(min_sum_hessian_in_leaf=0)
 
     est_lightgbm.fit(X_train, y_train)
     est_sklearn.fit(X_train, y_train)
@@ -77,14 +96,24 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples, max_leaf
 
     pred_lightgbm = est_lightgbm.predict(X_train)
     pred_sklearn = est_sklearn.predict(X_train)
-    # less than 1% of the predictions are different up to the 3rd decimal
-    assert np.mean(abs(pred_lightgbm - pred_sklearn) > 1e-3) < 0.011
-
-    if max_leaf_nodes < 10 and n_samples >= 1000:
+    if loss in ("gamma", "poisson"):
+        # More than 65% of the predictions must be close up to the 2nd decimal.
+        # TODO: We are not entirely satisfied with this lax comparison, but the root
+        # cause is not clear, maybe algorithmic differences. One such example is the
+        # poisson_max_delta_step parameter of LightGBM which does not exist in HGBT.
+        assert (
+            np.mean(np.isclose(pred_lightgbm, pred_sklearn, rtol=1e-2, atol=1e-2))
+            > 0.65
+        )
+    else:
+        # Less than 1% of the predictions may deviate more than 1e-3 in relative terms.
+        assert np.mean(np.isclose(pred_lightgbm, pred_sklearn, rtol=1e-3)) > 1 - 0.01
+
+    if max_leaf_nodes < 10 and n_samples >= 1000 and loss in ("squared_error",):
         pred_lightgbm = est_lightgbm.predict(X_test)
         pred_sklearn = est_sklearn.predict(X_test)
-        # less than 1% of the predictions are different up to the 4th decimal
-        assert np.mean(abs(pred_lightgbm - pred_sklearn) > 1e-4) < 0.01
+        # Less than 1% of the predictions may deviate more than 1e-4 in relative terms.
+        assert np.mean(np.isclose(pred_lightgbm, pred_sklearn, rtol=1e-4)) > 1 - 0.01
 
 
 @pytest.mark.parametrize("seed", range(5))
diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
index 2a89b834672db..7e774d9f09f45 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
@@ -17,7 +17,7 @@
 from sklearn.base import clone, BaseEstimator, TransformerMixin
 from sklearn.base import is_regressor
 from sklearn.pipeline import make_pipeline
-from sklearn.metrics import mean_poisson_deviance
+from sklearn.metrics import mean_gamma_deviance, mean_poisson_deviance
 from sklearn.dummy import DummyRegressor
 from sklearn.exceptions import NotFittedError
 from sklearn.compose import make_column_transformer
@@ -248,8 +248,64 @@ def test_absolute_error_sample_weight():
     gbdt.fit(X, y, sample_weight=sample_weight)
 
 
+@pytest.mark.parametrize("y", [([1.0, -2.0, 0.0]), ([0.0, 1.0, 2.0])])
+def test_gamma_y_positive(y):
+    # Test that ValueError is raised if any y_i <= 0.
+    err_msg = r"loss='gamma' requires strictly positive y."
+    gbdt = HistGradientBoostingRegressor(loss="gamma", random_state=0)
+    with pytest.raises(ValueError, match=err_msg):
+        gbdt.fit(np.zeros(shape=(len(y), 1)), y)
+
+
+def test_gamma():
+    # For a Gamma distributed target, we expect an HGBT trained with the Gamma deviance
+    # (loss) to give better results than an HGBT with any other loss function, measured
+    # in out-of-sample Gamma deviance as metric/score.
+    # Note that squared error could potentially predict negative values which is
+    # invalid (np.inf) for the Gamma deviance. A Poisson HGBT (having a log link)
+    # does not have that defect.
+    # Important note: It seems that a Poisson HGBT almost always has better
+    # out-of-sample performance than the Gamma HGBT, measured in Gamma deviance.
+    # LightGBM shows the same behaviour. Hence, we only compare to a squared error
+    # HGBT, but not to a Poisson deviance HGBT.
+    rng = np.random.RandomState(42)
+    n_train, n_test, n_features = 500, 100, 20
+    X = make_low_rank_matrix(
+        n_samples=n_train + n_test,
+        n_features=n_features,
+        random_state=rng,
+    )
+    # We create a log-linear Gamma model. This gives y.min ~ 1e-2, y.max ~ 1e2
+    coef = rng.uniform(low=-10, high=20, size=n_features)
+    # Numpy parametrizes gamma(shape=k, scale=theta) with mean = k * theta and
+    # variance = k * theta^2. We parametrize it instead with mean = exp(X @ coef)
+    # and variance = dispersion * mean^2 by setting k = 1 / dispersion,
+    # theta = dispersion * mean.
+    dispersion = 0.5
+    y = rng.gamma(shape=1 / dispersion, scale=dispersion * np.exp(X @ coef))
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, test_size=n_test, random_state=rng
+    )
+    gbdt_gamma = HistGradientBoostingRegressor(loss="gamma", random_state=123)
+    gbdt_mse = HistGradientBoostingRegressor(loss="squared_error", random_state=123)
+    dummy = DummyRegressor(strategy="mean")
+    for model in (gbdt_gamma, gbdt_mse, dummy):
+        model.fit(X_train, y_train)
+
+    for X, y in [(X_train, y_train), (X_test, y_test)]:
+        loss_gbdt_gamma = mean_gamma_deviance(y, gbdt_gamma.predict(X))
+        # We restrict the squared error HGBT to predict at least the minimum seen y at
+        # train time to make it strictly positive.
+        loss_gbdt_mse = mean_gamma_deviance(
+            y, np.maximum(np.min(y_train), gbdt_mse.predict(X))
+        )
+        loss_dummy = mean_gamma_deviance(y, dummy.predict(X))
+        assert loss_gbdt_gamma < loss_dummy
+        assert loss_gbdt_gamma < loss_gbdt_mse
+
+
 @pytest.mark.parametrize("quantile", [0.2, 0.5, 0.8])
-def test_asymmetric_error(quantile):
+def test_quantile_asymmetric_error(quantile):
     """Test quantile regression for asymmetric distributed targets."""
     n_samples = 10_000
     rng = np.random.RandomState(42)
diff --git a/sklearn/ensemble/_hist_gradient_boosting/utils.pyx b/sklearn/ensemble/_hist_gradient_boosting/utils.pyx
index d2123ecc61510..1c2f9f3db69e1 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/utils.pyx
+++ b/sklearn/ensemble/_hist_gradient_boosting/utils.pyx
@@ -41,6 +41,8 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None):
         'squared_error': 'regression_l2',
         'absolute_error': 'regression_l1',
         'log_loss': 'binary' if n_classes == 2 else 'multiclass',
+        'gamma': 'gamma',
+        'poisson': 'poisson',
     }
 
     lightgbm_params = {
@@ -53,13 +55,14 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None):
         'reg_lambda': sklearn_params['l2_regularization'],
         'max_bin': sklearn_params['max_bins'],
         'min_data_in_bin': 1,
-        'min_child_weight': 1e-3,
+        'min_child_weight': 1e-3,  # alias for 'min_sum_hessian_in_leaf'
         'min_sum_hessian_in_leaf': 1e-3,
         'min_split_gain': 0,
         'verbosity': 10 if sklearn_params['verbose'] else -10,
         'boost_from_average': True,
         'enable_bundle': False,  # also makes feature order consistent
         'subsample_for_bin': _BinMapper().subsample,
+        'poisson_max_delta_step': 1e-12,
     }
 
     if sklearn_params['loss'] == 'log_loss' and n_classes > 2:
@@ -76,6 +79,8 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None):
         'squared_error': 'reg:linear',
         'absolute_error': 'LEAST_ABSOLUTE_DEV_NOT_SUPPORTED',
         'log_loss': 'reg:logistic' if n_classes == 2 else 'multi:softmax',
+        'gamma': 'reg:gamma',
+        'poisson': 'count:poisson',
     }
 
     xgboost_params = {
@@ -100,6 +105,8 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None):
         # catboost does not support MAE when leaf_estimation_method is Newton
         'absolute_error': 'LEAST_ASBOLUTE_DEV_NOT_SUPPORTED',
         'log_loss': 'Logloss' if n_classes == 2 else 'MultiClass',
+        'gamma': None,
+        'poisson': 'Poisson',
     }
 
     catboost_params = {
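
For illustration only (not part of the patch): a minimal sketch of how the new `loss="gamma"` option can be exercised once this change is applied. It mirrors the setup of test_gamma above; the specific constants (sample sizes, dispersion, random seeds) are arbitrary choices for the example, not values prescribed by the patch.

import numpy as np
from sklearn.datasets import make_low_rank_matrix
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.metrics import mean_gamma_deviance
from sklearn.model_selection import train_test_split

rng = np.random.RandomState(0)
n_samples, n_features = 600, 20
X = make_low_rank_matrix(n_samples=n_samples, n_features=n_features, random_state=rng)
# Log-linear Gamma target: mean = exp(X @ coef), variance = dispersion * mean**2,
# i.e. numpy's gamma(shape=1 / dispersion, scale=dispersion * mean), as in test_gamma.
coef = rng.uniform(low=-10, high=20, size=n_features)
dispersion = 0.5
y = rng.gamma(shape=1 / dispersion, scale=dispersion * np.exp(X @ coef))

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)

# loss="gamma" requires strictly positive targets; zeros or negative values
# raise ValueError("loss='gamma' requires strictly positive y.").
gbdt = HistGradientBoostingRegressor(loss="gamma", random_state=0)
gbdt.fit(X_train, y_train)

# Predictions go through the internal log link, so they are strictly positive
# and can be scored with the Gamma deviance.
print(mean_gamma_deviance(y_test, gbdt.predict(X_test)))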