From 29adc46cf94f45749df20c6732cfee6e6f652ffe Mon Sep 17 00:00:00 2001 From: Yao Xiao <108576690+Charlie-XIAO@users.noreply.github.com> Date: Sun, 1 Oct 2023 20:45:46 +0800 Subject: [PATCH 01/13] support monotonic constraints in GradientBoosting* (first working version subject to change) --- sklearn/ensemble/_gb.py | 76 +++++++++++++++++++++-- sklearn/tree/_classes.py | 2 +- sklearn/tree/tests/test_monotonic_tree.py | 18 +++++- 3 files changed, 87 insertions(+), 9 deletions(-) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 5982f8a7fb952..25777578df830 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -130,6 +130,7 @@ def _update_terminal_regions( sample_mask, learning_rate=0.1, k=0, + line_search=True, ): """Update the leaf values to be predicted by the tree and raw_prediction. @@ -172,11 +173,14 @@ def _update_terminal_regions( ``learning_rate``. k : int, default=0 The index of the estimator being updated. + line_search : bool, default=True + Whether line search must be performed. Line search must not be + performed under monotonic constraints. """ # compute leaf for each sample in ``X``. terminal_regions = tree.apply(X) - if not isinstance(loss, HalfSquaredError): + if line_search and not isinstance(loss, HalfSquaredError): # mask all which are not in sample mask. masked_terminal_regions = terminal_regions.copy() masked_terminal_regions[~sample_mask] = -1 @@ -360,7 +364,6 @@ class BaseGradientBoosting(BaseEnsemble, metaclass=ABCMeta): "tol": [Interval(Real, 0.0, None, closed="left")], } _parameter_constraints.pop("splitter") - _parameter_constraints.pop("monotonic_cst") @abstractmethod def __init__( @@ -387,6 +390,7 @@ def __init__( validation_fraction=0.1, n_iter_no_change=None, tol=1e-4, + monotonic_cst=None, ): self.n_estimators = n_estimators self.learning_rate = learning_rate @@ -409,6 +413,7 @@ def __init__( self.validation_fraction = validation_fraction self.n_iter_no_change = n_iter_no_change self.tol = tol + self.monotonic_cst = monotonic_cst @abstractmethod def _encode_y(self, y=None, sample_weight=None): @@ -473,6 +478,7 @@ def _fit_stage( max_leaf_nodes=self.max_leaf_nodes, random_state=random_state, ccp_alpha=self.ccp_alpha, + monotonic_cst=self.monotonic_cst, ) if self.subsample < 1.0: @@ -497,6 +503,7 @@ def _fit_stage( sample_mask, learning_rate=self.learning_rate, k=k, + line_search=self.monotonic_cst is None, ) # add tree to ensemble @@ -645,13 +652,23 @@ def fit(self, X, y, sample_weight=None, monitor=None): if not self.warm_start: self._clear_state() - # Check input # Since check_array converts both X and y to the same dtype, but the # trees use different types for X and y, checking them separately. - X, y = self._validate_data( - X, y, accept_sparse=["csr", "csc", "coo"], dtype=DTYPE, multi_output=True + X, + y, + accept_sparse=["csr", "csc", "coo"], + dtype=DTYPE, + multi_output=True, ) + + # Raise now instead of specifying multi_output=False because we want a more + # explicit error message. + if self.monotonic_cst is not None and len(y.shape) > 1: + raise ValueError( + "Monotonicity constraints are not supported with multiple outputs" + ) + sample_weight_is_none = sample_weight is None sample_weight = _check_sample_weight(sample_weight, X) if sample_weight_is_none: @@ -1304,6 +1321,25 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): .. versionadded:: 0.22 + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. 
+ - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease + + If ``monotonic_cst`` is ``None``, no constraints are applied. + + Monotonicity constraints are not supported for: + - multiclass classifications (i.e. when `n_classes > 2`), + - multioutput classifications (i.e. when `n_outputs_ > 1`), + - classifications trained on data with missing values. + + The constraints hold over the probability of the positive class. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.4 + Attributes ---------- n_estimators_ : int @@ -1463,6 +1499,7 @@ def __init__( n_iter_no_change=None, tol=1e-4, ccp_alpha=0.0, + monotonic_cst=None, ): super().__init__( loss=loss, @@ -1485,6 +1522,7 @@ def __init__( n_iter_no_change=n_iter_no_change, tol=tol, ccp_alpha=ccp_alpha, + monotonic_cst=monotonic_cst, ) def _encode_y(self, y, sample_weight): @@ -1504,6 +1542,13 @@ def _encode_y(self, y, sample_weight): # From here on, it is additional to the HGBT case. # expose n_classes_ attribute self.n_classes_ = n_classes + + if self.monotonic_cst is not None and self.n_classes_ > 2: + raise ValueError( + "Monotonicity constraints are not supported with multiclass " + "classification" + ) + if sample_weight is None: n_trim_classes = n_classes else: @@ -1915,6 +1960,25 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): .. versionadded:: 0.22 + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease + + If ``monotonic_cst`` is ``None``, no constraints are applied. + + Monotonicity constraints are not supported for: + - multiclass classifications (i.e. when `n_classes > 2`), + - multioutput classifications (i.e. when `n_outputs_ > 1`), + - classifications trained on data with missing values. + + The constraints hold over the probability of the positive class. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.4 + Attributes ---------- n_estimators_ : int @@ -2058,6 +2122,7 @@ def __init__( n_iter_no_change=None, tol=1e-4, ccp_alpha=0.0, + monotonic_cst=None, ): super().__init__( loss=loss, @@ -2081,6 +2146,7 @@ def __init__( n_iter_no_change=n_iter_no_change, tol=tol, ccp_alpha=ccp_alpha, + monotonic_cst=monotonic_cst, ) def _encode_y(self, y=None, sample_weight=None): diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 03ba2f108bbdd..2217c398ae167 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -416,7 +416,7 @@ def _fit( else: if self.n_outputs_ > 1: raise ValueError( - "Monotonicity constraints are not supported with multiple outputs." 
+ "Monotonicity constraints are not supported with multiple outputs" ) # Check to correct monotonicity constraint' specification, # by applying element-wise logical conjunction diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index fe2f863d314ed..4358863529875 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -5,6 +5,8 @@ from sklearn.ensemble import ( ExtraTreesClassifier, ExtraTreesRegressor, + GradientBoostingClassifier, + GradientBoostingRegressor, RandomForestClassifier, RandomForestRegressor, ) @@ -21,10 +23,12 @@ TREE_BASED_CLASSIFIER_CLASSES = TREE_CLASSIFIER_CLASSES + [ RandomForestClassifier, ExtraTreesClassifier, + GradientBoostingClassifier, ] TREE_BASED_REGRESSOR_CLASSES = TREE_REGRESSOR_CLASSES + [ RandomForestRegressor, ExtraTreesRegressor, + GradientBoostingRegressor, ] @@ -91,7 +95,9 @@ def test_monotonic_constraints_classifications( @pytest.mark.parametrize("TreeRegressor", TREE_BASED_REGRESSOR_CLASSES) @pytest.mark.parametrize("depth_first_builder", (True, False)) @pytest.mark.parametrize("sparse_splitter", (True, False)) -@pytest.mark.parametrize("criterion", ("absolute_error", "squared_error")) +@pytest.mark.parametrize( + "criterion", ("friedman_mse", "squared_error", "absolute_error") +) @pytest.mark.parametrize("csc_container", CSC_CONTAINERS) def test_monotonic_constraints_regressions( TreeRegressor, @@ -101,6 +107,12 @@ def test_monotonic_constraints_regressions( global_random_seed, csc_container, ): + if ( + criterion in ("absolute_error", "poisson") + and TreeRegressor is GradientBoostingRegressor + ): + pytest.skip(f"{TreeRegressor.__name__} does not support criterion={criterion}") + n_samples = 1000 n_samples_train = 900 # Build a regression task using 5 informative features @@ -133,8 +145,8 @@ def test_monotonic_constraints_regressions( est = TreeRegressor( max_depth=8, monotonic_cst=monotonic_cst, - criterion=criterion, max_leaf_nodes=n_samples_train, + criterion=criterion, ) if hasattr(est, "random_state"): est.set_params(random_state=global_random_seed) @@ -179,7 +191,7 @@ def test_multiple_output_raises(TreeClassifier): est = TreeClassifier( max_depth=None, monotonic_cst=np.array([-1, 1]), random_state=0 ) - msg = "Monotonicity constraints are not supported with multiple output" + msg = "Monotonicity constraints are not supported with multiple outputs" with pytest.raises(ValueError, match=msg): est.fit(X, y) From ce4d65a6fb29e1b36c93cd87c0835ee9fc4aac54 Mon Sep 17 00:00:00 2001 From: Yao Xiao <108576690+Charlie-XIAO@users.noreply.github.com> Date: Tue, 3 Oct 2023 21:46:58 +0800 Subject: [PATCH 02/13] changelog added --- doc/whats_new/v1.4.rst | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 6afc5011ded30..4141eefce32d9 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -217,12 +217,14 @@ Changelog for regression. :pr:`26391` by `Thomas Fan`_. 
 - |Feature| :class:`ensemble.RandomForestClassifier`,
-  :class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier`
-  and :class:`ensemble.ExtraTreesRegressor` now support monotonic constraints,
+  :class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier`,
+  :class:`ensemble.ExtraTreesRegressor`, :class:`ensemble.GradientBoostingClassifier`,
+  and :class:`ensemble.GradientBoostingRegressor` now support monotonic constraints,
   useful when features are supposed to have a positive/negative effect on the target.
   Missing values in the train data and multi-output targets are not supported.
-  :pr:`13649` by :user:`Samuel Ronsin `,
-  initiated by :user:`Patrick O'Reilly `.
+  :pr:`13649` by :user:`Samuel Ronsin ` (initiated by
+  :user:`Patrick O'Reilly `), and :pr:`27516` by
+  :user:`Yao Xiao `.
 
 - |Efficiency| :class:`ensemble.GradientBoostingClassifier` is faster,
   for binary and in particular for multiclass problems thanks to the private loss

From c8ee23a2c64f5b50a6438b66c36b89398f7df067 Mon Sep 17 00:00:00 2001
From: Yao Xiao <108576690+Charlie-XIAO@users.noreply.github.com>
Date: Tue, 3 Oct 2023 21:50:08 +0800
Subject: [PATCH 03/13] retrigger CI

From ad0081395398942515c75ac7569f3564f55767bf Mon Sep 17 00:00:00 2001
From: Yao Xiao <108576690+Charlie-XIAO@users.noreply.github.com>
Date: Mon, 29 Jan 2024 16:50:11 +0800
Subject: [PATCH 04/13] move changelog and remove unrelated changes

---
 doc/whats_new/v1.4.rst                    | 10 ++++------
 doc/whats_new/v1.5.rst                    |  6 ++++++
 sklearn/tree/_classes.py                  |  2 +-
 sklearn/tree/tests/test_monotonic_tree.py |  4 ++--
 4 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst
index 41af2ff5536b2..f786693b84ee0 100644
--- a/doc/whats_new/v1.4.rst
+++ b/doc/whats_new/v1.4.rst
@@ -504,14 +504,12 @@ Changelog
   :pr:`27139` by :user:`Christian Lorentzen `.
 
 - |Feature| :class:`ensemble.RandomForestClassifier`,
-  :class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier`,
-  :class:`ensemble.ExtraTreesRegressor`, :class:`ensemble.GradientBoostingClassifier`,
-  and :class:`ensemble.GradientBoostingRegressor` now support monotonic constraints,
+  :class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier`
+  and :class:`ensemble.ExtraTreesRegressor` now support monotonic constraints,
   useful when features are supposed to have a positive/negative effect on the target.
   Missing values in the train data and multi-output targets are not supported.
-  :pr:`13649` by :user:`Samuel Ronsin ` (initiated by
-  :user:`Patrick O'Reilly `), and :pr:`27516` by
-  :user:`Yao Xiao `.
+  :pr:`13649` by :user:`Samuel Ronsin `,
+  initiated by :user:`Patrick O'Reilly `.
 
 - |Efficiency| :class:`ensemble.HistGradientBoostingClassifier` and
   :class:`ensemble.HistGradientBoostingRegressor` are now a bit faster by reusing
diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst
index 5818e4b25f044..3132557cd035e 100644
--- a/doc/whats_new/v1.5.rst
+++ b/doc/whats_new/v1.5.rst
@@ -67,6 +67,12 @@ Changelog
   :class:`ensemble.HistGradientBoostingClassifier` by avoiding to call
   `predict_proba`. :pr:`27844` by :user:`Christian Lorentzen `.
 
+- |Feature| :class:`ensemble.GradientBoostingClassifier` and
+  :class:`ensemble.GradientBoostingRegressor` now support monotonic constraints,
+  useful when features are supposed to have a positive/negative effect on the target.
+  Missing values in the train data and multi-output targets are not supported.
+  :pr:`27516` by :user:`Yao Xiao `.
+ :mod:`sklearn.feature_extraction` ................................. diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index e7fed77cf2b30..9f99d831a0990 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -394,7 +394,7 @@ def _fit( else: if self.n_outputs_ > 1: raise ValueError( - "Monotonicity constraints are not supported with multiple outputs" + "Monotonicity constraints are not supported with multiple outputs." ) # Check to correct monotonicity constraint' specification, # by applying element-wise logical conjunction diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 74675e15cefe4..3ae767022fa4a 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -151,8 +151,8 @@ def test_monotonic_constraints_regressions( est = TreeRegressor( max_depth=8, monotonic_cst=monotonic_cst, - max_leaf_nodes=n_samples_train, criterion=criterion, + max_leaf_nodes=n_samples_train, ) if hasattr(est, "random_state"): est.set_params(random_state=global_random_seed) @@ -197,7 +197,7 @@ def test_multiple_output_raises(TreeClassifier): est = TreeClassifier( max_depth=None, monotonic_cst=np.array([-1, 1]), random_state=0 ) - msg = "Monotonicity constraints are not supported with multiple outputs" + msg = "Monotonicity constraints are not supported with multiple output" with pytest.raises(ValueError, match=msg): est.fit(X, y) From 38460ccf768b32e22a5bdfb5eba20aeeebb2001a Mon Sep 17 00:00:00 2001 From: Yao Xiao <108576690+Charlie-XIAO@users.noreply.github.com> Date: Mon, 29 Jan 2024 20:11:51 +0800 Subject: [PATCH 05/13] fix tests --- sklearn/ensemble/_gb.py | 56 ++-- .../tests/test_monotonic_constraints.py | 289 ++++++++++++++++++ sklearn/tree/tests/test_monotonic_tree.py | 14 +- 3 files changed, 320 insertions(+), 39 deletions(-) create mode 100644 sklearn/ensemble/tests/test_monotonic_constraints.py diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 802981c8907a2..67fa4e80a7dc2 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -50,7 +50,11 @@ from ..utils._param_validation import HasMethods, Interval, StrOptions from ..utils.multiclass import check_classification_targets from ..utils.stats import _weighted_percentile -from ..utils.validation import _check_sample_weight, check_is_fitted +from ..utils.validation import ( + _check_monotonic_cst, + _check_sample_weight, + check_is_fitted, +) from ._base import BaseEnsemble from ._gradient_boosting import _random_sample_mask, predict_stage, predict_stages @@ -468,6 +472,8 @@ def _fit_stage( else: neg_g_view = neg_gradient + monotonic_cst = _check_monotonic_cst(self, self.monotonic_cst) + for k in range(self.n_trees_per_iteration_): if self._loss.is_multiclass: y = np.array(original_y == k, dtype=np.float64) @@ -485,7 +491,7 @@ def _fit_stage( max_leaf_nodes=self.max_leaf_nodes, random_state=random_state, ccp_alpha=self.ccp_alpha, - monotonic_cst=self.monotonic_cst, + monotonic_cst=monotonic_cst, ) if self.subsample < 1.0: @@ -1330,21 +1336,20 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): .. versionadded:: 0.22 - monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonicity constraint to enforce on each feature. 
- - 1: monotonic increase - - 0: no constraint - - -1: monotonic decrease + monotonic_cst : array-like of int of shape (n_features) or dict, default=None + Monotonic constraint to enforce on each feature are specified using the + following integer values: - If ``monotonic_cst`` is ``None``, no constraints are applied. + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease - Monotonicity constraints are not supported for: - - multiclass classifications (i.e. when `n_classes > 2`), - - multioutput classifications (i.e. when `n_outputs_ > 1`), - - classifications trained on data with missing values. - - The constraints hold over the probability of the positive class. + If a dict with str keys, map feature to monotonic constraints by name. + If an array, the features are mapped to constraints by position. See + :ref:`monotonic_cst_features_names` for a usage example. + The constraints are only valid for binary classifications and hold + over the probability of the positive class. Read more in the :ref:`User Guide `. .. versionadded:: 1.4 @@ -1969,21 +1974,20 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): .. versionadded:: 0.22 - monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonicity constraint to enforce on each feature. - - 1: monotonic increase - - 0: no constraint - - -1: monotonic decrease - - If ``monotonic_cst`` is ``None``, no constraints are applied. + monotonic_cst : array-like of int of shape (n_features) or dict, default=None + Monotonic constraint to enforce on each feature are specified using the + following integer values: - Monotonicity constraints are not supported for: - - multiclass classifications (i.e. when `n_classes > 2`), - - multioutput classifications (i.e. when `n_outputs_ > 1`), - - classifications trained on data with missing values. + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease - The constraints hold over the probability of the positive class. + If a dict with str keys, map feature to monotonic constraints by name. + If an array, the features are mapped to constraints by position. See + :ref:`monotonic_cst_features_names` for a usage example. + The constraints are only valid for binary classifications and hold + over the probability of the positive class. Read more in the :ref:`User Guide `. .. 
versionadded:: 1.4 diff --git a/sklearn/ensemble/tests/test_monotonic_constraints.py b/sklearn/ensemble/tests/test_monotonic_constraints.py new file mode 100644 index 0000000000000..1209fbb5116fa --- /dev/null +++ b/sklearn/ensemble/tests/test_monotonic_constraints.py @@ -0,0 +1,289 @@ +import re + +import numpy as np +import pytest +from numpy.testing import assert_array_almost_equal + +from sklearn.datasets import make_classification, make_regression +from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor +from sklearn.utils._testing import _convert_container +from sklearn.utils.fixes import CSC_CONTAINERS + + +@pytest.mark.parametrize("depth_first_builder", (True, False)) +@pytest.mark.parametrize("criterion", ("friedman_mse", "squared_error")) +@pytest.mark.parametrize("csc_container", [None] + CSC_CONTAINERS) +def test_monotonic_constraints_regressor( + depth_first_builder, criterion, global_random_seed, csc_container +): + n_samples = 1000 + n_samples_train = 900 + # Build a regression task using 5 informative features + X, y = make_regression( + n_samples=n_samples, + n_features=5, + n_informative=5, + random_state=global_random_seed, + ) + train = np.arange(n_samples_train) + test = np.arange(n_samples_train, n_samples) + X_train = X[train] + y_train = y[train] + X_test = np.copy(X[test]) + X_test_incr = np.copy(X_test) + X_test_decr = np.copy(X_test) + X_test_incr[:, 0] += 10 + X_test_decr[:, 1] += 10 + monotonic_cst = np.zeros(X.shape[1]) + monotonic_cst[0] = 1 + monotonic_cst[1] = -1 + + if csc_container is not None: + X_train = csc_container(X_train) + + params = { + "monotonic_cst": monotonic_cst, + "criterion": criterion, + "random_state": global_random_seed, + "n_estimators": 5, + } + if depth_first_builder: + params["max_depth"] = None + else: + params["max_depth"] = 8 + params["max_leaf_nodes"] = n_samples_train + + est = GradientBoostingRegressor(**params) + est.fit(X_train, y_train) + y = est.predict(X_test) + + # Monotonic increase constraint + y_incr = est.predict(X_test_incr) + # y_incr should always be greater than y + assert np.all(y_incr >= y) + + # Monotonic decrease constraint + y_decr = est.predict(X_test_decr) + # y_decr should always be lower than y + assert np.all(y_decr <= y) + + +@pytest.mark.parametrize("depth_first_builder", (True, False)) +@pytest.mark.parametrize("csc_container", [None] + CSC_CONTAINERS) +def test_monotonic_constraints_classifier( + depth_first_builder, global_random_seed, csc_container +): + n_samples = 1000 + n_samples_train = 900 + X, y = make_classification( + n_samples=n_samples, + n_classes=2, + n_features=5, + n_informative=5, + n_redundant=0, + random_state=global_random_seed, + ) + X_train, y_train = X[:n_samples_train], y[:n_samples_train] + X_test, _ = X[n_samples_train:], y[n_samples_train:] + + X_test_0incr, X_test_0decr = np.copy(X_test), np.copy(X_test) + X_test_1incr, X_test_1decr = np.copy(X_test), np.copy(X_test) + X_test_0incr[:, 0] += 10 + X_test_0decr[:, 0] -= 10 + X_test_1incr[:, 1] += 10 + X_test_1decr[:, 1] -= 10 + monotonic_cst = np.zeros(X.shape[1]) + monotonic_cst[0] = 1 + monotonic_cst[1] = -1 + + if csc_container is not None: + X_train = csc_container(X_train) + + params = { + "monotonic_cst": monotonic_cst, + "random_state": global_random_seed, + "n_estimators": 5, + } + if depth_first_builder: + params["max_depth"] = None + else: + params["max_depth"] = 8 + params["max_leaf_nodes"] = n_samples_train + + est = GradientBoostingClassifier(**params) + est.fit(X_train, y_train) + 
proba_test = est.predict_proba(X_test) + + assert np.logical_and( + proba_test >= 0.0, proba_test <= 1.0 + ).all(), "Probability should always be in [0, 1] range." + assert_array_almost_equal(proba_test.sum(axis=1), 1.0) + + # Monotonic increase constraint, it applies to the positive class + assert np.all(est.predict_proba(X_test_0incr)[:, 1] >= proba_test[:, 1]) + assert np.all(est.predict_proba(X_test_0decr)[:, 1] <= proba_test[:, 1]) + + # Monotonic decrease constraint, it applies to the positive class + assert np.all(est.predict_proba(X_test_1incr)[:, 1] <= proba_test[:, 1]) + assert np.all(est.predict_proba(X_test_1decr)[:, 1] >= proba_test[:, 1]) + + +@pytest.mark.parametrize("use_feature_names", (True, False)) +def test_predictions(global_random_seed, use_feature_names): + # This is the same test as the one for HistGradientBoostingRegressor + rng = np.random.RandomState(global_random_seed) + + n_samples = 1000 + f_0 = rng.rand(n_samples) # positive correlation with y + f_1 = rng.rand(n_samples) # negative correslation with y + X = np.c_[f_0, f_1] + columns_name = ["f_0", "f_1"] + constructor_name = "dataframe" if use_feature_names else "array" + X = _convert_container(X, constructor_name, columns_name=columns_name) + + noise = rng.normal(loc=0.0, scale=0.01, size=n_samples) + y = 5 * f_0 + np.sin(10 * np.pi * f_0) - 5 * f_1 - np.cos(10 * np.pi * f_1) + noise + + if use_feature_names: + monotonic_cst = {"f_0": +1, "f_1": -1} + else: + monotonic_cst = [+1, -1] + + gbdt = GradientBoostingRegressor(monotonic_cst=monotonic_cst) + gbdt.fit(X, y) + + linspace = np.linspace(0, 1, 100) + sin = np.sin(linspace) + constant = np.full_like(linspace, fill_value=0.5) + + # First feature (POS) + # assert pred is all increasing when f_0 is all increasing + X = np.c_[linspace, constant] + X = _convert_container(X, constructor_name, columns_name=columns_name) + pred = gbdt.predict(X) + assert (np.diff(pred) >= 0.0).all() + # assert pred actually follows the variations of f_0 + X = np.c_[sin, constant] + X = _convert_container(X, constructor_name, columns_name=columns_name) + pred = gbdt.predict(X) + assert np.all((np.diff(pred) >= 0) == (np.diff(sin) >= 0)) + + # Second feature (NEG) + # assert pred is all decreasing when f_1 is all increasing + X = np.c_[constant, linspace] + X = _convert_container(X, constructor_name, columns_name=columns_name) + pred = gbdt.predict(X) + assert (np.diff(pred) <= 0.0).all() + # assert pred actually follows the inverse variations of f_1 + X = np.c_[constant, sin] + X = _convert_container(X, constructor_name, columns_name=columns_name) + pred = gbdt.predict(X) + assert ((np.diff(pred) <= 0) == (np.diff(sin) >= 0)).all() + + +def test_multiclass_raises(): + X, y = make_classification( + n_samples=100, n_features=5, n_classes=3, n_informative=3, random_state=0 + ) + y[0] = 0 + monotonic_cst = np.zeros(X.shape[1]) + monotonic_cst[0] = -1 + monotonic_cst[1] = 1 + est = GradientBoostingClassifier( + max_depth=None, monotonic_cst=monotonic_cst, random_state=0 + ) + + msg = "Monotonicity constraints are not supported with multiclass classification" + with pytest.raises(ValueError, match=msg): + est.fit(X, y) + + +def test_multiple_output_raises(): + X = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]] + y = [[1, 0, 1, 0, 1], [1, 0, 1, 0, 1]] + + est = GradientBoostingClassifier( + max_depth=None, monotonic_cst=np.array([-1, 1]), random_state=0 + ) + msg = "Monotonicity constraints are not supported with multiple output" + with pytest.raises(ValueError, match=msg): + est.fit(X, y) + + +def 
test_missing_values_raises(): + X, y = make_classification( + n_samples=100, n_features=5, n_classes=2, n_informative=3, random_state=0 + ) + X[0, 0] = np.nan + monotonic_cst = np.zeros(X.shape[1]) + monotonic_cst[0] = 1 + est = GradientBoostingClassifier( + max_depth=None, monotonic_cst=monotonic_cst, random_state=0 + ) + + msg = "Input X contains NaN" + with pytest.raises(ValueError, match=msg): + est.fit(X, y) + + +def test_bad_monotonic_cst_raises(): + X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] + y = [1, 0, 1, 0, 1] + + msg = "monotonic_cst has shape (3,) but the input data X has 2 features." + est = GradientBoostingClassifier( + max_depth=None, monotonic_cst=np.array([-1, 1, 0]), random_state=0 + ) + with pytest.raises(ValueError, match=re.escape(msg)): + est.fit(X, y) + + msg = "monotonic_cst must be an array-like of -1, 0 or 1." + est = GradientBoostingClassifier( + max_depth=None, monotonic_cst=np.array([-2, 2]), random_state=0 + ) + with pytest.raises(ValueError, match=msg): + est.fit(X, y) + + est = GradientBoostingClassifier( + max_depth=None, monotonic_cst=np.array([-1, 0.8]), random_state=0 + ) + with pytest.raises(ValueError, match=msg + "(.*)0.8]"): + est.fit(X, y) + + +def test_bad_monotonic_cst_related_to_feature_names(): + pd = pytest.importorskip("pandas") + X = pd.DataFrame({"a": [0, 1, 2], "b": [0, 1, 2]}) + y = np.array([0, 1, 0]) + + monotonic_cst = {"d": 1, "a": 1, "c": -1} + gbdt = GradientBoostingRegressor(monotonic_cst=monotonic_cst) + expected_msg = re.escape( + "monotonic_cst contains 2 unexpected feature names: ['c', 'd']." + ) + with pytest.raises(ValueError, match=expected_msg): + gbdt.fit(X, y) + + monotonic_cst = {k: 1 for k in "abcdefghijklmnopqrstuvwxyz"} + gbdt = GradientBoostingRegressor(monotonic_cst=monotonic_cst) + expected_msg = re.escape( + "monotonic_cst contains 24 unexpected feature names: " + "['c', 'd', 'e', 'f', 'g', '...']." + ) + with pytest.raises(ValueError, match=expected_msg): + gbdt.fit(X, y) + + monotonic_cst = {"a": 1} + gbdt = GradientBoostingRegressor(monotonic_cst=monotonic_cst) + expected_msg = re.escape( + "GradientBoostingRegressor was not fitted on data with feature " + "names. Pass monotonic_cst as an integer array instead." + ) + with pytest.raises(ValueError, match=expected_msg): + gbdt.fit(X.values, y) + + monotonic_cst = {"b": -1, "a": "+"} + gbdt = GradientBoostingRegressor(monotonic_cst=monotonic_cst) + expected_msg = re.escape("monotonic_cst['a'] must be either -1, 0 or 1. 
Got '+'.") + with pytest.raises(ValueError, match=expected_msg): + gbdt.fit(X, y) diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py index 3ae767022fa4a..6478c2e2dfd85 100644 --- a/sklearn/tree/tests/test_monotonic_tree.py +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -5,8 +5,6 @@ from sklearn.ensemble import ( ExtraTreesClassifier, ExtraTreesRegressor, - GradientBoostingClassifier, - GradientBoostingRegressor, RandomForestClassifier, RandomForestRegressor, ) @@ -24,12 +22,10 @@ TREE_BASED_CLASSIFIER_CLASSES = TREE_CLASSIFIER_CLASSES + [ RandomForestClassifier, ExtraTreesClassifier, - GradientBoostingClassifier, ] TREE_BASED_REGRESSOR_CLASSES = TREE_REGRESSOR_CLASSES + [ RandomForestRegressor, ExtraTreesRegressor, - GradientBoostingRegressor, ] @@ -101,9 +97,7 @@ def test_monotonic_constraints_classifications( @pytest.mark.parametrize("TreeRegressor", TREE_BASED_REGRESSOR_CLASSES) @pytest.mark.parametrize("depth_first_builder", (True, False)) @pytest.mark.parametrize("sparse_splitter", (True, False)) -@pytest.mark.parametrize( - "criterion", ("friedman_mse", "squared_error", "absolute_error") -) +@pytest.mark.parametrize("criterion", ("absolute_error", "squared_error")) @pytest.mark.parametrize("csc_container", CSC_CONTAINERS) def test_monotonic_constraints_regressions( TreeRegressor, @@ -113,12 +107,6 @@ def test_monotonic_constraints_regressions( global_random_seed, csc_container, ): - if ( - criterion in ("absolute_error", "poisson") - and TreeRegressor is GradientBoostingRegressor - ): - pytest.skip(f"{TreeRegressor.__name__} does not support criterion={criterion}") - n_samples = 1000 n_samples_train = 900 # Build a regression task using 5 informative features From cbeb2724cd5ea42c0ba23f715a4c046c422bae29 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Sun, 21 Jul 2024 02:10:41 +0800 Subject: [PATCH 06/13] record boundary info in nodes --- sklearn/ensemble/_gb.py | 12 ++++++------ sklearn/tree/_tree.pxd | 7 +++++-- sklearn/tree/_tree.pyx | 24 ++++++++++++++++++++---- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index be349776cebf9..e1bb0af2467fc 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -140,7 +140,6 @@ def _update_terminal_regions( sample_mask, learning_rate=0.1, k=0, - line_search=True, ): """Update the leaf values to be predicted by the tree and raw_prediction. @@ -183,14 +182,11 @@ def _update_terminal_regions( ``learning_rate``. k : int, default=0 The index of the estimator being updated. - line_search : bool, default=True - Whether line search must be performed. Line search must not be - performed under monotonic constraints. """ # compute leaf for each sample in ``X``. terminal_regions = tree.apply(X) - if line_search and not isinstance(loss, HalfSquaredError): + if not isinstance(loss, HalfSquaredError): # mask all which are not in sample mask. masked_terminal_regions = terminal_regions.copy() masked_terminal_regions[~sample_mask] = -1 @@ -262,6 +258,11 @@ def compute_update(y_, indices, neg_gradient, raw_prediction, k): sw = None if sample_weight is None else sample_weight[indices] update = compute_update(y_, indices, neg_gradient, raw_prediction, k) + if update > tree.upper_bound[leaf]: + update = tree.upper_bound[leaf] + elif update < tree.lower_bound[leaf]: + update = tree.lower_bound[leaf] + # TODO: Multiply here by learning rate instead of everywhere else. 
tree.value[leaf, 0, 0] = update @@ -515,7 +516,6 @@ def _fit_stage( sample_mask, learning_rate=self.learning_rate, k=k, - line_search=self.monotonic_cst is None, ) # add tree to ensemble diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 831ca38a11148..50b63a354b043 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -22,7 +22,8 @@ cdef struct Node: intp_t n_node_samples # Number of samples at the node float64_t weighted_n_node_samples # Weighted number of samples at the node uint8_t missing_go_to_left # Whether features have missing values - + float64_t lower_bound # Lower bound of the node's impurity + float64_t upper_bound # Upper bound of the node's impurity cdef struct ParentInfo: # Structure to store information about the parent of a node @@ -58,7 +59,9 @@ cdef class Tree: intp_t feature, float64_t threshold, float64_t impurity, intp_t n_node_samples, float64_t weighted_n_node_samples, - uint8_t missing_go_to_left) except -1 nogil + uint8_t missing_go_to_left, + float64_t lower_bound, + float64_t upper_bound) except -1 nogil cdef int _resize(self, intp_t capacity) except -1 nogil cdef int _resize_c(self, intp_t capacity=*) except -1 nogil diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 43b7770131497..75d6a74fd61ac 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -268,7 +268,9 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): node_id = tree._add_node(parent, is_left, is_leaf, split.feature, split.threshold, parent_record.impurity, n_node_samples, weighted_n_node_samples, - split.missing_go_to_left) + split.missing_go_to_left, + parent_record.lower_bound, + parent_record.upper_bound) if node_id == INTPTR_MAX: rc = -1 @@ -626,7 +628,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): is_left, is_leaf, split.feature, split.threshold, parent_record.impurity, n_node_samples, weighted_n_node_samples, - split.missing_go_to_left) + split.missing_go_to_left, + parent_record.lower_bound, parent_record.upper_bound) if node_id == INTPTR_MAX: return -1 @@ -774,6 +777,14 @@ cdef class Tree: def missing_go_to_left(self): return self._get_node_ndarray()['missing_go_to_left'][:self.node_count] + @property + def lower_bound(self): + return self._get_node_ndarray()['lower_bound'][:self.node_count] + + @property + def upper_bound(self): + return self._get_node_ndarray()['upper_bound'][:self.node_count] + @property def value(self): return self._get_value_ndarray()[:self.node_count] @@ -910,7 +921,9 @@ cdef class Tree: intp_t feature, float64_t threshold, float64_t impurity, intp_t n_node_samples, float64_t weighted_n_node_samples, - uint8_t missing_go_to_left) except -1 nogil: + uint8_t missing_go_to_left, + float64_t lower_bound, + float64_t upper_bound) except -1 nogil: """Add a node to the tree. The new node registers itself as the child of its parent. 
@@ -927,6 +940,8 @@ cdef class Tree: node.impurity = impurity node.n_node_samples = n_node_samples node.weighted_n_node_samples = weighted_n_node_samples + node.lower_bound = lower_bound + node.upper_bound = upper_bound if parent != _TREE_UNDEFINED: if is_left: @@ -1934,7 +1949,8 @@ cdef _build_pruned_tree( new_node_id = tree._add_node( parent, is_left, is_leaf, node.feature, node.threshold, node.impurity, node.n_node_samples, - node.weighted_n_node_samples, node.missing_go_to_left) + node.weighted_n_node_samples, node.missing_go_to_left, + node.lower_bound, node.upper_bound) if new_node_id == INTPTR_MAX: rc = -1 From 5f6b6d8b8161863fda88616d9123f2d32f34d2d5 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Sun, 21 Jul 2024 12:12:31 +0800 Subject: [PATCH 07/13] change versionadded --- sklearn/ensemble/_gb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index e1bb0af2467fc..d1d85a56ce13a 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -1349,7 +1349,7 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): over the probability of the positive class. Read more in the :ref:`User Guide `. - .. versionadded:: 1.4 + .. versionadded:: 1.6 Attributes ---------- @@ -1987,7 +1987,7 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): over the probability of the positive class. Read more in the :ref:`User Guide `. - .. versionadded:: 1.4 + .. versionadded:: 1.6 Attributes ---------- From 3ac75ead3ecd24d313a43d82b8d6c1af698e22b8 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 13 Aug 2024 01:36:50 +0800 Subject: [PATCH 08/13] avoid modifying the node structure --- sklearn/ensemble/_gb.py | 32 +++++++++++--- sklearn/tree/_classes.py | 21 +++++++-- sklearn/tree/_tree.pxd | 10 ++++- sklearn/tree/_tree.pyx | 95 +++++++++++++++++++++++++++++++++++----- 4 files changed, 136 insertions(+), 22 deletions(-) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 078c22d992362..cfb172f94fe88 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -140,6 +140,8 @@ def _update_terminal_regions( sample_mask, learning_rate=0.1, k=0, + lower_bounds=None, + upper_bounds=None, ): """Update the leaf values to be predicted by the tree and raw_prediction. @@ -182,9 +184,14 @@ def _update_terminal_regions( ``learning_rate``. k : int, default=0 The index of the estimator being updated. + lower_bounds : ndarray of shape (node_count,), default=None + The lower bounds for the tree nodes. + upper_bounds : ndarray of shape (node_count,), default=None + The upper bounds for the tree nodes. """ # compute leaf for each sample in ``X``. terminal_regions = tree.apply(X) + use_bounds = upper_bounds is not None and lower_bounds is not None if not isinstance(loss, HalfSquaredError): # mask all which are not in sample mask. @@ -258,10 +265,11 @@ def compute_update(y_, indices, neg_gradient, raw_prediction, k): sw = None if sample_weight is None else sample_weight[indices] update = compute_update(y_, indices, neg_gradient, raw_prediction, k) - if update > tree.upper_bound[leaf]: - update = tree.upper_bound[leaf] - elif update < tree.lower_bound[leaf]: - update = tree.lower_bound[leaf] + if use_bounds: + if update > upper_bounds[leaf]: + update = upper_bounds[leaf] + elif update < lower_bounds[leaf]: + update = lower_bounds[leaf] # TODO: Multiply here by learning rate instead of everywhere else. 
tree.value[leaf, 0, 0] = update @@ -499,10 +507,20 @@ def _fit_stage( sample_weight = sample_weight * sample_mask.astype(np.float64) X = X_csc if X_csc is not None else X - tree.fit( - X, neg_g_view[:, k], sample_weight=sample_weight, check_input=False + tree._fit( + X, + neg_g_view[:, k], + sample_weight=sample_weight, + check_input=False, + record_node_boundaries=self.monotonic_cst is not None, ) + if self.monotonic_cst is None: + lower_bounds, upper_bounds = None, None + else: + lower_bounds = tree.tree_.lower_bounds + upper_bounds = tree.tree_.upper_bounds + # update tree leaves X_for_tree_update = X_csr if X_csr is not None else X _update_terminal_regions( @@ -516,6 +534,8 @@ def _fit_stage( sample_mask, learning_rate=self.learning_rate, k=k, + lower_bounds=lower_bounds, + upper_bounds=upper_bounds, ) # add tree to ensemble diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index efa5eb6e8f84d..8854b4adaa117 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -227,6 +227,7 @@ def _fit( sample_weight=None, check_input=True, missing_values_in_feature_mask=None, + record_node_boundaries=False, ): random_state = check_random_state(self.random_state) @@ -431,13 +432,19 @@ def _fit( ) if is_classifier(self): - self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_) + self.tree_ = Tree( + self.n_features_in_, + self.n_classes_, + self.n_outputs_, + record_node_boundaries=record_node_boundaries, + ) else: self.tree_ = Tree( self.n_features_in_, # TODO: tree shouldn't need this in this case np.array([1] * self.n_outputs_, dtype=np.intp), self.n_outputs_, + record_node_boundaries=record_node_boundaries, ) # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise @@ -467,7 +474,7 @@ def _fit( self.n_classes_ = self.n_classes_[0] self.classes_ = self.classes_[0] - self._prune_tree() + self._prune_tree(record_node_boundaries=record_node_boundaries) return self @@ -598,7 +605,7 @@ def decision_path(self, X, check_input=True): X = self._validate_X_predict(X, check_input) return self.tree_.decision_path(X) - def _prune_tree(self): + def _prune_tree(self, record_node_boundaries=False): """Prune tree using Minimal Cost-Complexity Pruning.""" check_is_fitted(self) @@ -608,13 +615,19 @@ def _prune_tree(self): # build pruned tree if is_classifier(self): n_classes = np.atleast_1d(self.n_classes_) - pruned_tree = Tree(self.n_features_in_, n_classes, self.n_outputs_) + pruned_tree = Tree( + self.n_features_in_, + n_classes, + self.n_outputs_, + record_node_boundaries=record_node_boundaries, + ) else: pruned_tree = Tree( self.n_features_in_, # TODO: the tree shouldn't need this param np.array([1] * self.n_outputs_, dtype=np.intp), self.n_outputs_, + record_node_boundaries=record_node_boundaries, ) _build_pruned_tree_ccp(pruned_tree, self.tree_, self.ccp_alpha) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 50b63a354b043..1de189bfd93aa 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -22,8 +22,6 @@ cdef struct Node: intp_t n_node_samples # Number of samples at the node float64_t weighted_n_node_samples # Weighted number of samples at the node uint8_t missing_go_to_left # Whether features have missing values - float64_t lower_bound # Lower bound of the node's impurity - float64_t upper_bound # Upper bound of the node's impurity cdef struct ParentInfo: # Structure to store information about the parent of a node @@ -54,6 +52,12 @@ cdef class Tree: cdef float64_t* value # (capacity, n_outputs, max_n_classes) array of 
values cdef intp_t value_stride # = n_outputs * max_n_classes + # Lower and upper boundaries of nodes, used for monotonic constraints of gradient + # boosting; they are left uninitialized otherwise + cdef bint record_node_boundaries # Whether to record the node boundaries + cdef float64_t* lower_bounds # Array of lower boundaries of nodes + cdef float64_t* upper_bounds # Array of upper boundaries of nodes + # Methods cdef intp_t _add_node(self, intp_t parent, bint is_left, bint is_leaf, intp_t feature, float64_t threshold, float64_t impurity, @@ -67,6 +71,8 @@ cdef class Tree: cdef cnp.ndarray _get_value_ndarray(self) cdef cnp.ndarray _get_node_ndarray(self) + cdef cnp.ndarray _get_lower_bounds_ndarray(self) + cdef cnp.ndarray _get_upper_bounds_ndarray(self) cpdef cnp.ndarray predict(self, object X) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 75d6a74fd61ac..ca9f8fcf13201 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -778,20 +778,25 @@ cdef class Tree: return self._get_node_ndarray()['missing_go_to_left'][:self.node_count] @property - def lower_bound(self): - return self._get_node_ndarray()['lower_bound'][:self.node_count] + def value(self): + return self._get_value_ndarray()[:self.node_count] @property - def upper_bound(self): - return self._get_node_ndarray()['upper_bound'][:self.node_count] + def lower_bounds(self): + if not self.record_node_boundaries: + raise ValueError("Tree was not built with record_node_boundaries=True") + return self._get_lower_bounds_ndarray()[:self.node_count] @property - def value(self): - return self._get_value_ndarray()[:self.node_count] + def upper_bounds(self): + if not self.record_node_boundaries: + raise ValueError("Tree was not built with record_node_boundaries=True") + return self._get_upper_bounds_ndarray()[:self.node_count] # TODO: Convert n_classes to cython.integral memory view once # https://github.com/cython/cython/issues/5243 is fixed - def __cinit__(self, intp_t n_features, cnp.ndarray n_classes, intp_t n_outputs): + def __cinit__(self, intp_t n_features, cnp.ndarray n_classes, intp_t n_outputs, + bool record_node_boundaries=False): """Constructor.""" cdef intp_t dummy = 0 size_t_dtype = np.array(dummy).dtype @@ -817,6 +822,10 @@ cdef class Tree: self.capacity = 0 self.value = NULL self.nodes = NULL + self.lower_bounds = NULL + self.upper_bounds = NULL + + self.record_node_boundaries = record_node_boundaries def __dealloc__(self): """Destructor.""" @@ -824,6 +833,8 @@ cdef class Tree: free(self.n_classes) free(self.value) free(self.nodes) + free(self.lower_bounds) + free(self.upper_bounds) def __reduce__(self): """Reduce re-implementation, for pickling.""" @@ -839,12 +850,19 @@ cdef class Tree: d["node_count"] = self.node_count d["nodes"] = self._get_node_ndarray() d["values"] = self._get_value_ndarray() + + d["record_node_boundaries"] = self.record_node_boundaries + if self.record_node_boundaries: + d["lower_bounds"] = self._get_lower_bounds_ndarray() + d["upper_bounds"] = self._get_upper_bounds_ndarray() + return d def __setstate__(self, d): """Setstate re-implementation, for unpickling.""" self.max_depth = d["max_depth"] self.node_count = d["node_count"] + self.record_node_boundaries = d["record_node_boundaries"] if 'nodes' not in d: raise ValueError('You have loaded Tree version which ' @@ -872,6 +890,15 @@ cdef class Tree: memcpy(self.value, cnp.PyArray_DATA(value_ndarray), self.capacity * self.value_stride * sizeof(float64_t)) + if self.record_node_boundaries: + lower_bounds_ndarray = 
d["lower_bounds"] + upper_bounds_ndarray = d["upper_bounds"] + + memcpy(self.lower_bounds, cnp.PyArray_DATA(lower_bounds_ndarray), + self.capacity * sizeof(float64_t)) + memcpy(self.upper_bounds, cnp.PyArray_DATA(upper_bounds_ndarray), + self.capacity * sizeof(float64_t)) + cdef int _resize(self, intp_t capacity) except -1 nogil: """Resize all inner arrays to `capacity`, if `capacity` == -1, then double the size of the inner arrays. @@ -901,6 +928,9 @@ cdef class Tree: safe_realloc(&self.nodes, capacity) safe_realloc(&self.value, capacity * self.value_stride) + if self.record_node_boundaries: + safe_realloc(&self.lower_bounds, capacity) + safe_realloc(&self.upper_bounds, capacity) if capacity > self.capacity: # value memory is initialised to 0 to enable classifier argmax @@ -910,6 +940,13 @@ cdef class Tree: # node memory is initialised to 0 to ensure deterministic pickle (padding in Node struct) memset((self.nodes + self.capacity), 0, (capacity - self.capacity) * sizeof(Node)) + if self.record_node_boundaries: + # node boundaries are initialised to 0 to ensure deterministic pickle + memset((self.lower_bounds + self.capacity), 0, + (capacity - self.capacity) * sizeof(float64_t)) + memset((self.upper_bounds + self.capacity), 0, + (capacity - self.capacity) * sizeof(float64_t)) + # if capacity smaller than node_count, adjust the counter if capacity < self.node_count: self.node_count = capacity @@ -940,8 +977,6 @@ cdef class Tree: node.impurity = impurity node.n_node_samples = n_node_samples node.weighted_n_node_samples = weighted_n_node_samples - node.lower_bound = lower_bound - node.upper_bound = upper_bound if parent != _TREE_UNDEFINED: if is_left: @@ -961,6 +996,10 @@ cdef class Tree: node.threshold = threshold node.missing_go_to_left = missing_go_to_left + if self.record_node_boundaries: + self.lower_bounds[node_id] = lower_bound + self.upper_bounds[node_id] = upper_bound + self.node_count += 1 return node_id @@ -1340,6 +1379,36 @@ cdef class Tree: raise ValueError("Can't initialize array.") return arr + cdef cnp.ndarray _get_lower_bounds_ndarray(self): + """Wraps lower bounds as a NumPy array. + + The array keeps a reference to this Tree, which manages the underlying + memory. + """ + cdef cnp.npy_intp shape[1] + shape[0] = self.node_count + cdef cnp.ndarray arr + arr = cnp.PyArray_SimpleNewFromData(1, shape, cnp.NPY_DOUBLE, self.lower_bounds) + Py_INCREF(self) + if PyArray_SetBaseObject(arr, self) < 0: + raise ValueError("Can't initialize array.") + return arr + + cdef cnp.ndarray _get_upper_bounds_ndarray(self): + """Wraps upper bounds as a NumPy array. + + The array keeps a reference to this Tree, which manages the underlying + memory. 
+ """ + cdef cnp.npy_intp shape[1] + shape[0] = self.node_count + cdef cnp.ndarray arr + arr = cnp.PyArray_SimpleNewFromData(1, shape, cnp.NPY_DOUBLE, self.upper_bounds) + Py_INCREF(self) + if PyArray_SetBaseObject(arr, self) < 0: + raise ValueError("Can't initialize array.") + return arr + def compute_partial_dependence(self, float32_t[:, ::1] X, const intp_t[::1] target_features, float64_t[::1] out): @@ -1927,6 +1996,9 @@ cdef _build_pruned_tree( float64_t* orig_value_ptr float64_t* new_value_ptr + float64_t lower_bound + float64_t upper_bound + stack[BuildPrunedRecord] prune_stack BuildPrunedRecord stack_record @@ -1946,11 +2018,14 @@ cdef _build_pruned_tree( is_leaf = leaves_in_subtree[orig_node_id] node = &orig_tree.nodes[orig_node_id] + lower_bound = orig_tree.lower_bounds[orig_node_id] + upper_bound = orig_tree.lower_bounds[orig_node_id] + new_node_id = tree._add_node( parent, is_left, is_leaf, node.feature, node.threshold, node.impurity, node.n_node_samples, node.weighted_n_node_samples, node.missing_go_to_left, - node.lower_bound, node.upper_bound) + lower_bound, upper_bound) if new_node_id == INTPTR_MAX: rc = -1 From d141301d0afa057a3cd88ead563c92b40a889a00 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 13 Aug 2024 12:17:34 +0800 Subject: [PATCH 09/13] try fix segfault --- sklearn/tree/_tree.pyx | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index ca9f8fcf13201..9ffd53b017e53 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -1996,6 +1996,7 @@ cdef _build_pruned_tree( float64_t* orig_value_ptr float64_t* new_value_ptr + bint record_node_boundaries = tree.record_node_boundaries and orig_tree.record_node_boundaries float64_t lower_bound float64_t upper_bound @@ -2018,8 +2019,12 @@ cdef _build_pruned_tree( is_leaf = leaves_in_subtree[orig_node_id] node = &orig_tree.nodes[orig_node_id] - lower_bound = orig_tree.lower_bounds[orig_node_id] - upper_bound = orig_tree.lower_bounds[orig_node_id] + if record_node_boundaries: + lower_bound = orig_tree.lower_bounds[orig_node_id] + upper_bound = orig_tree.lower_bounds[orig_node_id] + else: + lower_bound = -INFINITY + upper_bound = INFINITY new_node_id = tree._add_node( parent, is_left, is_leaf, node.feature, node.threshold, From e26e12dc5561ad8b3af8aea8d96f13691299f199 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Wed, 4 Sep 2024 13:55:00 -0400 Subject: [PATCH 10/13] revert changes and add comment in docs --- doc/whats_new/v1.6.rst | 6 - sklearn/ensemble/_gb.py | 113 +------ .../tests/test_monotonic_constraints.py | 289 ------------------ sklearn/tree/_classes.py | 21 +- sklearn/tree/_tree.pxd | 13 +- sklearn/tree/_tree.pyx | 106 +------ 6 files changed, 24 insertions(+), 524 deletions(-) delete mode 100644 sklearn/ensemble/tests/test_monotonic_constraints.py diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index eafd9b3201bfe..d4d373e9bf280 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -237,12 +237,6 @@ Changelog right child node as the tree is traversed. :pr:`28268` by :user:`Adam Li `. -- |Feature| :class:`ensemble.GradientBoostingClassifier` and - :class:`ensemble.GradientBoostingRegressor` now support monotonic constraints, - useful when features are supposed to have a positive/negative effect on the target. - Missing values in the train data and multi-output targets are not supported. - :pr:`27516` by :user:`Yao Xiao `. - :mod:`sklearn.impute` ..................... 
diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index cfb172f94fe88..d1eeafa00dccd 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -49,11 +49,7 @@ from ..utils._param_validation import HasMethods, Interval, StrOptions from ..utils.multiclass import check_classification_targets from ..utils.stats import _weighted_percentile -from ..utils.validation import ( - _check_monotonic_cst, - _check_sample_weight, - check_is_fitted, -) +from ..utils.validation import _check_sample_weight, check_is_fitted from ._base import BaseEnsemble from ._gradient_boosting import _random_sample_mask, predict_stage, predict_stages @@ -140,8 +136,6 @@ def _update_terminal_regions( sample_mask, learning_rate=0.1, k=0, - lower_bounds=None, - upper_bounds=None, ): """Update the leaf values to be predicted by the tree and raw_prediction. @@ -184,14 +178,9 @@ def _update_terminal_regions( ``learning_rate``. k : int, default=0 The index of the estimator being updated. - lower_bounds : ndarray of shape (node_count,), default=None - The lower bounds for the tree nodes. - upper_bounds : ndarray of shape (node_count,), default=None - The upper bounds for the tree nodes. """ # compute leaf for each sample in ``X``. terminal_regions = tree.apply(X) - use_bounds = upper_bounds is not None and lower_bounds is not None if not isinstance(loss, HalfSquaredError): # mask all which are not in sample mask. @@ -265,12 +254,6 @@ def compute_update(y_, indices, neg_gradient, raw_prediction, k): sw = None if sample_weight is None else sample_weight[indices] update = compute_update(y_, indices, neg_gradient, raw_prediction, k) - if use_bounds: - if update > upper_bounds[leaf]: - update = upper_bounds[leaf] - elif update < lower_bounds[leaf]: - update = lower_bounds[leaf] - # TODO: Multiply here by learning rate instead of everywhere else. 
tree.value[leaf, 0, 0] = update @@ -383,6 +366,7 @@ class BaseGradientBoosting(BaseEnsemble, metaclass=ABCMeta): "tol": [Interval(Real, 0.0, None, closed="left")], } _parameter_constraints.pop("splitter") + _parameter_constraints.pop("monotonic_cst") @abstractmethod def __init__( @@ -409,7 +393,6 @@ def __init__( validation_fraction=0.1, n_iter_no_change=None, tol=1e-4, - monotonic_cst=None, ): self.n_estimators = n_estimators self.learning_rate = learning_rate @@ -432,7 +415,6 @@ def __init__( self.validation_fraction = validation_fraction self.n_iter_no_change = n_iter_no_change self.tol = tol - self.monotonic_cst = monotonic_cst @abstractmethod def _encode_y(self, y=None, sample_weight=None): @@ -480,8 +462,6 @@ def _fit_stage( else: neg_g_view = neg_gradient - monotonic_cst = _check_monotonic_cst(self, self.monotonic_cst) - for k in range(self.n_trees_per_iteration_): if self._loss.is_multiclass: y = np.array(original_y == k, dtype=np.float64) @@ -499,7 +479,6 @@ def _fit_stage( max_leaf_nodes=self.max_leaf_nodes, random_state=random_state, ccp_alpha=self.ccp_alpha, - monotonic_cst=monotonic_cst, ) if self.subsample < 1.0: @@ -507,20 +486,10 @@ def _fit_stage( sample_weight = sample_weight * sample_mask.astype(np.float64) X = X_csc if X_csc is not None else X - tree._fit( - X, - neg_g_view[:, k], - sample_weight=sample_weight, - check_input=False, - record_node_boundaries=self.monotonic_cst is not None, + tree.fit( + X, neg_g_view[:, k], sample_weight=sample_weight, check_input=False ) - if self.monotonic_cst is None: - lower_bounds, upper_bounds = None, None - else: - lower_bounds = tree.tree_.lower_bounds - upper_bounds = tree.tree_.upper_bounds - # update tree leaves X_for_tree_update = X_csr if X_csr is not None else X _update_terminal_regions( @@ -534,8 +503,6 @@ def _fit_stage( sample_mask, learning_rate=self.learning_rate, k=k, - lower_bounds=lower_bounds, - upper_bounds=upper_bounds, ) # add tree to ensemble @@ -684,23 +651,13 @@ def fit(self, X, y, sample_weight=None, monitor=None): if not self.warm_start: self._clear_state() + # Check input # Since check_array converts both X and y to the same dtype, but the # trees use different types for X and y, checking them separately. + X, y = self._validate_data( - X, - y, - accept_sparse=["csr", "csc", "coo"], - dtype=DTYPE, - multi_output=True, + X, y, accept_sparse=["csr", "csc", "coo"], dtype=DTYPE, multi_output=True ) - - # Raise now instead of specifying multi_output=False because we want a more - # explicit error message. - if self.monotonic_cst is not None and len(y.shape) > 1: - raise ValueError( - "Monotonicity constraints are not supported with multiple outputs" - ) - sample_weight_is_none = sample_weight is None sample_weight = _check_sample_weight(sample_weight, X) if sample_weight_is_none: @@ -1166,8 +1123,9 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): classification is a special case where only a single regression tree is induced. - :class:`sklearn.ensemble.HistGradientBoostingClassifier` is a much faster - variant of this algorithm for intermediate datasets (`n_samples >= 10_000`). + :class:`~sklearn.ensemble.HistGradientBoostingClassifier` is a much faster + variant of this algorithm for intermediate datasets (`n_samples >= 10_000`) and + supports monotonicity constraints. Read more in the :ref:`User Guide `. @@ -1353,24 +1311,6 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): .. 
versionadded:: 0.22 - monotonic_cst : array-like of int of shape (n_features) or dict, default=None - Monotonic constraint to enforce on each feature are specified using the - following integer values: - - - 1: monotonic increase - - 0: no constraint - - -1: monotonic decrease - - If a dict with str keys, map feature to monotonic constraints by name. - If an array, the features are mapped to constraints by position. See - :ref:`monotonic_cst_features_names` for a usage example. - - The constraints are only valid for binary classifications and hold - over the probability of the positive class. - Read more in the :ref:`User Guide `. - - .. versionadded:: 1.6 - Attributes ---------- n_estimators_ : int @@ -1530,7 +1470,6 @@ def __init__( n_iter_no_change=None, tol=1e-4, ccp_alpha=0.0, - monotonic_cst=None, ): super().__init__( loss=loss, @@ -1553,7 +1492,6 @@ def __init__( n_iter_no_change=n_iter_no_change, tol=tol, ccp_alpha=ccp_alpha, - monotonic_cst=monotonic_cst, ) def _encode_y(self, y, sample_weight): @@ -1574,12 +1512,6 @@ def _encode_y(self, y, sample_weight): # expose n_classes_ attribute self.n_classes_ = n_classes - if self.monotonic_cst is not None and self.n_classes_ > 2: - raise ValueError( - "Monotonicity constraints are not supported with multiclass " - "classification" - ) - if sample_weight is None: n_trim_classes = n_classes else: @@ -1796,8 +1728,9 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): each stage a regression tree is fit on the negative gradient of the given loss function. - :class:`sklearn.ensemble.HistGradientBoostingRegressor` is a much faster - variant of this algorithm for intermediate datasets (`n_samples >= 10_000`). + :class:`~sklearn.ensemble.HistGradientBoostingRegressor` is a much faster + variant of this algorithm for intermediate datasets (`n_samples >= 10_000`) and + supports monotonic constraints. Read more in the :ref:`User Guide `. @@ -1991,24 +1924,6 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): .. versionadded:: 0.22 - monotonic_cst : array-like of int of shape (n_features) or dict, default=None - Monotonic constraint to enforce on each feature are specified using the - following integer values: - - - 1: monotonic increase - - 0: no constraint - - -1: monotonic decrease - - If a dict with str keys, map feature to monotonic constraints by name. - If an array, the features are mapped to constraints by position. See - :ref:`monotonic_cst_features_names` for a usage example. - - The constraints are only valid for binary classifications and hold - over the probability of the positive class. - Read more in the :ref:`User Guide `. - - .. 
versionadded:: 1.6 - Attributes ---------- n_estimators_ : int @@ -2157,7 +2072,6 @@ def __init__( n_iter_no_change=None, tol=1e-4, ccp_alpha=0.0, - monotonic_cst=None, ): super().__init__( loss=loss, @@ -2181,7 +2095,6 @@ def __init__( n_iter_no_change=n_iter_no_change, tol=tol, ccp_alpha=ccp_alpha, - monotonic_cst=monotonic_cst, ) def _encode_y(self, y=None, sample_weight=None): diff --git a/sklearn/ensemble/tests/test_monotonic_constraints.py b/sklearn/ensemble/tests/test_monotonic_constraints.py deleted file mode 100644 index 1209fbb5116fa..0000000000000 --- a/sklearn/ensemble/tests/test_monotonic_constraints.py +++ /dev/null @@ -1,289 +0,0 @@ -import re - -import numpy as np -import pytest -from numpy.testing import assert_array_almost_equal - -from sklearn.datasets import make_classification, make_regression -from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor -from sklearn.utils._testing import _convert_container -from sklearn.utils.fixes import CSC_CONTAINERS - - -@pytest.mark.parametrize("depth_first_builder", (True, False)) -@pytest.mark.parametrize("criterion", ("friedman_mse", "squared_error")) -@pytest.mark.parametrize("csc_container", [None] + CSC_CONTAINERS) -def test_monotonic_constraints_regressor( - depth_first_builder, criterion, global_random_seed, csc_container -): - n_samples = 1000 - n_samples_train = 900 - # Build a regression task using 5 informative features - X, y = make_regression( - n_samples=n_samples, - n_features=5, - n_informative=5, - random_state=global_random_seed, - ) - train = np.arange(n_samples_train) - test = np.arange(n_samples_train, n_samples) - X_train = X[train] - y_train = y[train] - X_test = np.copy(X[test]) - X_test_incr = np.copy(X_test) - X_test_decr = np.copy(X_test) - X_test_incr[:, 0] += 10 - X_test_decr[:, 1] += 10 - monotonic_cst = np.zeros(X.shape[1]) - monotonic_cst[0] = 1 - monotonic_cst[1] = -1 - - if csc_container is not None: - X_train = csc_container(X_train) - - params = { - "monotonic_cst": monotonic_cst, - "criterion": criterion, - "random_state": global_random_seed, - "n_estimators": 5, - } - if depth_first_builder: - params["max_depth"] = None - else: - params["max_depth"] = 8 - params["max_leaf_nodes"] = n_samples_train - - est = GradientBoostingRegressor(**params) - est.fit(X_train, y_train) - y = est.predict(X_test) - - # Monotonic increase constraint - y_incr = est.predict(X_test_incr) - # y_incr should always be greater than y - assert np.all(y_incr >= y) - - # Monotonic decrease constraint - y_decr = est.predict(X_test_decr) - # y_decr should always be lower than y - assert np.all(y_decr <= y) - - -@pytest.mark.parametrize("depth_first_builder", (True, False)) -@pytest.mark.parametrize("csc_container", [None] + CSC_CONTAINERS) -def test_monotonic_constraints_classifier( - depth_first_builder, global_random_seed, csc_container -): - n_samples = 1000 - n_samples_train = 900 - X, y = make_classification( - n_samples=n_samples, - n_classes=2, - n_features=5, - n_informative=5, - n_redundant=0, - random_state=global_random_seed, - ) - X_train, y_train = X[:n_samples_train], y[:n_samples_train] - X_test, _ = X[n_samples_train:], y[n_samples_train:] - - X_test_0incr, X_test_0decr = np.copy(X_test), np.copy(X_test) - X_test_1incr, X_test_1decr = np.copy(X_test), np.copy(X_test) - X_test_0incr[:, 0] += 10 - X_test_0decr[:, 0] -= 10 - X_test_1incr[:, 1] += 10 - X_test_1decr[:, 1] -= 10 - monotonic_cst = np.zeros(X.shape[1]) - monotonic_cst[0] = 1 - monotonic_cst[1] = -1 - - if csc_container is 
not None: - X_train = csc_container(X_train) - - params = { - "monotonic_cst": monotonic_cst, - "random_state": global_random_seed, - "n_estimators": 5, - } - if depth_first_builder: - params["max_depth"] = None - else: - params["max_depth"] = 8 - params["max_leaf_nodes"] = n_samples_train - - est = GradientBoostingClassifier(**params) - est.fit(X_train, y_train) - proba_test = est.predict_proba(X_test) - - assert np.logical_and( - proba_test >= 0.0, proba_test <= 1.0 - ).all(), "Probability should always be in [0, 1] range." - assert_array_almost_equal(proba_test.sum(axis=1), 1.0) - - # Monotonic increase constraint, it applies to the positive class - assert np.all(est.predict_proba(X_test_0incr)[:, 1] >= proba_test[:, 1]) - assert np.all(est.predict_proba(X_test_0decr)[:, 1] <= proba_test[:, 1]) - - # Monotonic decrease constraint, it applies to the positive class - assert np.all(est.predict_proba(X_test_1incr)[:, 1] <= proba_test[:, 1]) - assert np.all(est.predict_proba(X_test_1decr)[:, 1] >= proba_test[:, 1]) - - -@pytest.mark.parametrize("use_feature_names", (True, False)) -def test_predictions(global_random_seed, use_feature_names): - # This is the same test as the one for HistGradientBoostingRegressor - rng = np.random.RandomState(global_random_seed) - - n_samples = 1000 - f_0 = rng.rand(n_samples) # positive correlation with y - f_1 = rng.rand(n_samples) # negative correslation with y - X = np.c_[f_0, f_1] - columns_name = ["f_0", "f_1"] - constructor_name = "dataframe" if use_feature_names else "array" - X = _convert_container(X, constructor_name, columns_name=columns_name) - - noise = rng.normal(loc=0.0, scale=0.01, size=n_samples) - y = 5 * f_0 + np.sin(10 * np.pi * f_0) - 5 * f_1 - np.cos(10 * np.pi * f_1) + noise - - if use_feature_names: - monotonic_cst = {"f_0": +1, "f_1": -1} - else: - monotonic_cst = [+1, -1] - - gbdt = GradientBoostingRegressor(monotonic_cst=monotonic_cst) - gbdt.fit(X, y) - - linspace = np.linspace(0, 1, 100) - sin = np.sin(linspace) - constant = np.full_like(linspace, fill_value=0.5) - - # First feature (POS) - # assert pred is all increasing when f_0 is all increasing - X = np.c_[linspace, constant] - X = _convert_container(X, constructor_name, columns_name=columns_name) - pred = gbdt.predict(X) - assert (np.diff(pred) >= 0.0).all() - # assert pred actually follows the variations of f_0 - X = np.c_[sin, constant] - X = _convert_container(X, constructor_name, columns_name=columns_name) - pred = gbdt.predict(X) - assert np.all((np.diff(pred) >= 0) == (np.diff(sin) >= 0)) - - # Second feature (NEG) - # assert pred is all decreasing when f_1 is all increasing - X = np.c_[constant, linspace] - X = _convert_container(X, constructor_name, columns_name=columns_name) - pred = gbdt.predict(X) - assert (np.diff(pred) <= 0.0).all() - # assert pred actually follows the inverse variations of f_1 - X = np.c_[constant, sin] - X = _convert_container(X, constructor_name, columns_name=columns_name) - pred = gbdt.predict(X) - assert ((np.diff(pred) <= 0) == (np.diff(sin) >= 0)).all() - - -def test_multiclass_raises(): - X, y = make_classification( - n_samples=100, n_features=5, n_classes=3, n_informative=3, random_state=0 - ) - y[0] = 0 - monotonic_cst = np.zeros(X.shape[1]) - monotonic_cst[0] = -1 - monotonic_cst[1] = 1 - est = GradientBoostingClassifier( - max_depth=None, monotonic_cst=monotonic_cst, random_state=0 - ) - - msg = "Monotonicity constraints are not supported with multiclass classification" - with pytest.raises(ValueError, match=msg): - est.fit(X, y) - - 
-def test_multiple_output_raises(): - X = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]] - y = [[1, 0, 1, 0, 1], [1, 0, 1, 0, 1]] - - est = GradientBoostingClassifier( - max_depth=None, monotonic_cst=np.array([-1, 1]), random_state=0 - ) - msg = "Monotonicity constraints are not supported with multiple output" - with pytest.raises(ValueError, match=msg): - est.fit(X, y) - - -def test_missing_values_raises(): - X, y = make_classification( - n_samples=100, n_features=5, n_classes=2, n_informative=3, random_state=0 - ) - X[0, 0] = np.nan - monotonic_cst = np.zeros(X.shape[1]) - monotonic_cst[0] = 1 - est = GradientBoostingClassifier( - max_depth=None, monotonic_cst=monotonic_cst, random_state=0 - ) - - msg = "Input X contains NaN" - with pytest.raises(ValueError, match=msg): - est.fit(X, y) - - -def test_bad_monotonic_cst_raises(): - X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] - y = [1, 0, 1, 0, 1] - - msg = "monotonic_cst has shape (3,) but the input data X has 2 features." - est = GradientBoostingClassifier( - max_depth=None, monotonic_cst=np.array([-1, 1, 0]), random_state=0 - ) - with pytest.raises(ValueError, match=re.escape(msg)): - est.fit(X, y) - - msg = "monotonic_cst must be an array-like of -1, 0 or 1." - est = GradientBoostingClassifier( - max_depth=None, monotonic_cst=np.array([-2, 2]), random_state=0 - ) - with pytest.raises(ValueError, match=msg): - est.fit(X, y) - - est = GradientBoostingClassifier( - max_depth=None, monotonic_cst=np.array([-1, 0.8]), random_state=0 - ) - with pytest.raises(ValueError, match=msg + "(.*)0.8]"): - est.fit(X, y) - - -def test_bad_monotonic_cst_related_to_feature_names(): - pd = pytest.importorskip("pandas") - X = pd.DataFrame({"a": [0, 1, 2], "b": [0, 1, 2]}) - y = np.array([0, 1, 0]) - - monotonic_cst = {"d": 1, "a": 1, "c": -1} - gbdt = GradientBoostingRegressor(monotonic_cst=monotonic_cst) - expected_msg = re.escape( - "monotonic_cst contains 2 unexpected feature names: ['c', 'd']." - ) - with pytest.raises(ValueError, match=expected_msg): - gbdt.fit(X, y) - - monotonic_cst = {k: 1 for k in "abcdefghijklmnopqrstuvwxyz"} - gbdt = GradientBoostingRegressor(monotonic_cst=monotonic_cst) - expected_msg = re.escape( - "monotonic_cst contains 24 unexpected feature names: " - "['c', 'd', 'e', 'f', 'g', '...']." - ) - with pytest.raises(ValueError, match=expected_msg): - gbdt.fit(X, y) - - monotonic_cst = {"a": 1} - gbdt = GradientBoostingRegressor(monotonic_cst=monotonic_cst) - expected_msg = re.escape( - "GradientBoostingRegressor was not fitted on data with feature " - "names. Pass monotonic_cst as an integer array instead." - ) - with pytest.raises(ValueError, match=expected_msg): - gbdt.fit(X.values, y) - - monotonic_cst = {"b": -1, "a": "+"} - gbdt = GradientBoostingRegressor(monotonic_cst=monotonic_cst) - expected_msg = re.escape("monotonic_cst['a'] must be either -1, 0 or 1. 
Got '+'.") - with pytest.raises(ValueError, match=expected_msg): - gbdt.fit(X, y) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 67f9ebe5ed51f..abfb836a6ec27 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -227,7 +227,6 @@ def _fit( sample_weight=None, check_input=True, missing_values_in_feature_mask=None, - record_node_boundaries=False, ): random_state = check_random_state(self.random_state) @@ -432,19 +431,13 @@ def _fit( ) if is_classifier(self): - self.tree_ = Tree( - self.n_features_in_, - self.n_classes_, - self.n_outputs_, - record_node_boundaries=record_node_boundaries, - ) + self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_) else: self.tree_ = Tree( self.n_features_in_, # TODO: tree shouldn't need this in this case np.array([1] * self.n_outputs_, dtype=np.intp), self.n_outputs_, - record_node_boundaries=record_node_boundaries, ) # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise @@ -474,7 +467,7 @@ def _fit( self.n_classes_ = self.n_classes_[0] self.classes_ = self.classes_[0] - self._prune_tree(record_node_boundaries=record_node_boundaries) + self._prune_tree() return self @@ -605,7 +598,7 @@ def decision_path(self, X, check_input=True): X = self._validate_X_predict(X, check_input) return self.tree_.decision_path(X) - def _prune_tree(self, record_node_boundaries=False): + def _prune_tree(self): """Prune tree using Minimal Cost-Complexity Pruning.""" check_is_fitted(self) @@ -615,19 +608,13 @@ def _prune_tree(self, record_node_boundaries=False): # build pruned tree if is_classifier(self): n_classes = np.atleast_1d(self.n_classes_) - pruned_tree = Tree( - self.n_features_in_, - n_classes, - self.n_outputs_, - record_node_boundaries=record_node_boundaries, - ) + pruned_tree = Tree(self.n_features_in_, n_classes, self.n_outputs_) else: pruned_tree = Tree( self.n_features_in_, # TODO: the tree shouldn't need this param np.array([1] * self.n_outputs_, dtype=np.intp), self.n_outputs_, - record_node_boundaries=record_node_boundaries, ) _build_pruned_tree_ccp(pruned_tree, self.tree_, self.ccp_alpha) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index f80b6eff74e3f..2cadca4564a87 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -23,6 +23,7 @@ cdef struct Node: float64_t weighted_n_node_samples # Weighted number of samples at the node uint8_t missing_go_to_left # Whether features have missing values + cdef struct ParentInfo: # Structure to store information about the parent of a node # This is passed to the splitter, to provide information about the previous split @@ -52,27 +53,17 @@ cdef class Tree: cdef float64_t* value # (capacity, n_outputs, max_n_classes) array of values cdef intp_t value_stride # = n_outputs * max_n_classes - # Lower and upper boundaries of nodes, used for monotonic constraints of gradient - # boosting; they are left uninitialized otherwise - cdef bint record_node_boundaries # Whether to record the node boundaries - cdef float64_t* lower_bounds # Array of lower boundaries of nodes - cdef float64_t* upper_bounds # Array of upper boundaries of nodes - # Methods cdef intp_t _add_node(self, intp_t parent, bint is_left, bint is_leaf, intp_t feature, float64_t threshold, float64_t impurity, intp_t n_node_samples, float64_t weighted_n_node_samples, - uint8_t missing_go_to_left, - float64_t lower_bound, - float64_t upper_bound) except -1 nogil + uint8_t missing_go_to_left) except -1 nogil cdef int _resize(self, intp_t capacity) except -1 nogil cdef int 
_resize_c(self, intp_t capacity=*) except -1 nogil cdef cnp.ndarray _get_value_ndarray(self) cdef cnp.ndarray _get_node_ndarray(self) - cdef cnp.ndarray _get_lower_bounds_ndarray(self) - cdef cnp.ndarray _get_upper_bounds_ndarray(self) cpdef cnp.ndarray predict(self, object X) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 3aadd783e71e6..7e6946a718a81 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -268,9 +268,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): node_id = tree._add_node(parent, is_left, is_leaf, split.feature, split.threshold, parent_record.impurity, n_node_samples, weighted_n_node_samples, - split.missing_go_to_left, - parent_record.lower_bound, - parent_record.upper_bound) + split.missing_go_to_left) if node_id == INTPTR_MAX: rc = -1 @@ -628,8 +626,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): is_left, is_leaf, split.feature, split.threshold, parent_record.impurity, n_node_samples, weighted_n_node_samples, - split.missing_go_to_left, - parent_record.lower_bound, parent_record.upper_bound) + split.missing_go_to_left) if node_id == INTPTR_MAX: return -1 @@ -781,22 +778,9 @@ cdef class Tree: def value(self): return self._get_value_ndarray()[:self.node_count] - @property - def lower_bounds(self): - if not self.record_node_boundaries: - raise ValueError("Tree was not built with record_node_boundaries=True") - return self._get_lower_bounds_ndarray()[:self.node_count] - - @property - def upper_bounds(self): - if not self.record_node_boundaries: - raise ValueError("Tree was not built with record_node_boundaries=True") - return self._get_upper_bounds_ndarray()[:self.node_count] - # TODO: Convert n_classes to cython.integral memory view once # https://github.com/cython/cython/issues/5243 is fixed - def __cinit__(self, intp_t n_features, cnp.ndarray n_classes, intp_t n_outputs, - bool record_node_boundaries=False): + def __cinit__(self, intp_t n_features, cnp.ndarray n_classes, intp_t n_outputs): """Constructor.""" cdef intp_t dummy = 0 size_t_dtype = np.array(dummy).dtype @@ -822,10 +806,6 @@ cdef class Tree: self.capacity = 0 self.value = NULL self.nodes = NULL - self.lower_bounds = NULL - self.upper_bounds = NULL - - self.record_node_boundaries = record_node_boundaries def __dealloc__(self): """Destructor.""" @@ -833,8 +813,6 @@ cdef class Tree: free(self.n_classes) free(self.value) free(self.nodes) - free(self.lower_bounds) - free(self.upper_bounds) def __reduce__(self): """Reduce re-implementation, for pickling.""" @@ -850,19 +828,12 @@ cdef class Tree: d["node_count"] = self.node_count d["nodes"] = self._get_node_ndarray() d["values"] = self._get_value_ndarray() - - d["record_node_boundaries"] = self.record_node_boundaries - if self.record_node_boundaries: - d["lower_bounds"] = self._get_lower_bounds_ndarray() - d["upper_bounds"] = self._get_upper_bounds_ndarray() - return d def __setstate__(self, d): """Setstate re-implementation, for unpickling.""" self.max_depth = d["max_depth"] self.node_count = d["node_count"] - self.record_node_boundaries = d["record_node_boundaries"] if 'nodes' not in d: raise ValueError('You have loaded Tree version which ' @@ -890,15 +861,6 @@ cdef class Tree: memcpy(self.value, cnp.PyArray_DATA(value_ndarray), self.capacity * self.value_stride * sizeof(float64_t)) - if self.record_node_boundaries: - lower_bounds_ndarray = d["lower_bounds"] - upper_bounds_ndarray = d["upper_bounds"] - - memcpy(self.lower_bounds, cnp.PyArray_DATA(lower_bounds_ndarray), - self.capacity * sizeof(float64_t)) - 
memcpy(self.upper_bounds, cnp.PyArray_DATA(upper_bounds_ndarray), - self.capacity * sizeof(float64_t)) - cdef int _resize(self, intp_t capacity) except -1 nogil: """Resize all inner arrays to `capacity`, if `capacity` == -1, then double the size of the inner arrays. @@ -928,9 +890,6 @@ cdef class Tree: safe_realloc(&self.nodes, capacity) safe_realloc(&self.value, capacity * self.value_stride) - if self.record_node_boundaries: - safe_realloc(&self.lower_bounds, capacity) - safe_realloc(&self.upper_bounds, capacity) if capacity > self.capacity: # value memory is initialised to 0 to enable classifier argmax @@ -940,13 +899,6 @@ cdef class Tree: # node memory is initialised to 0 to ensure deterministic pickle (padding in Node struct) memset((self.nodes + self.capacity), 0, (capacity - self.capacity) * sizeof(Node)) - if self.record_node_boundaries: - # node boundaries are initialised to 0 to ensure deterministic pickle - memset((self.lower_bounds + self.capacity), 0, - (capacity - self.capacity) * sizeof(float64_t)) - memset((self.upper_bounds + self.capacity), 0, - (capacity - self.capacity) * sizeof(float64_t)) - # if capacity smaller than node_count, adjust the counter if capacity < self.node_count: self.node_count = capacity @@ -958,9 +910,7 @@ cdef class Tree: intp_t feature, float64_t threshold, float64_t impurity, intp_t n_node_samples, float64_t weighted_n_node_samples, - uint8_t missing_go_to_left, - float64_t lower_bound, - float64_t upper_bound) except -1 nogil: + uint8_t missing_go_to_left) except -1 nogil: """Add a node to the tree. The new node registers itself as the child of its parent. @@ -996,10 +946,6 @@ cdef class Tree: node.threshold = threshold node.missing_go_to_left = missing_go_to_left - if self.record_node_boundaries: - self.lower_bounds[node_id] = lower_bound - self.upper_bounds[node_id] = upper_bound - self.node_count += 1 return node_id @@ -1379,36 +1325,6 @@ cdef class Tree: raise ValueError("Can't initialize array.") return arr - cdef cnp.ndarray _get_lower_bounds_ndarray(self): - """Wraps lower bounds as a NumPy array. - - The array keeps a reference to this Tree, which manages the underlying - memory. - """ - cdef cnp.npy_intp shape[1] - shape[0] = self.node_count - cdef cnp.ndarray arr - arr = cnp.PyArray_SimpleNewFromData(1, shape, cnp.NPY_DOUBLE, self.lower_bounds) - Py_INCREF(self) - if PyArray_SetBaseObject(arr, self) < 0: - raise ValueError("Can't initialize array.") - return arr - - cdef cnp.ndarray _get_upper_bounds_ndarray(self): - """Wraps upper bounds as a NumPy array. - - The array keeps a reference to this Tree, which manages the underlying - memory. 
- """ - cdef cnp.npy_intp shape[1] - shape[0] = self.node_count - cdef cnp.ndarray arr - arr = cnp.PyArray_SimpleNewFromData(1, shape, cnp.NPY_DOUBLE, self.upper_bounds) - Py_INCREF(self) - if PyArray_SetBaseObject(arr, self) < 0: - raise ValueError("Can't initialize array.") - return arr - def compute_partial_dependence(self, float32_t[:, ::1] X, const intp_t[::1] target_features, float64_t[::1] out): @@ -1996,10 +1912,6 @@ cdef void _build_pruned_tree( float64_t* orig_value_ptr float64_t* new_value_ptr - bint record_node_boundaries = tree.record_node_boundaries and orig_tree.record_node_boundaries - float64_t lower_bound - float64_t upper_bound - stack[BuildPrunedRecord] prune_stack BuildPrunedRecord stack_record @@ -2028,18 +1940,10 @@ cdef void _build_pruned_tree( rc = -2 break - if record_node_boundaries: - lower_bound = orig_tree.lower_bounds[orig_node_id] - upper_bound = orig_tree.lower_bounds[orig_node_id] - else: - lower_bound = -INFINITY - upper_bound = INFINITY - new_node_id = tree._add_node( parent, is_left, is_leaf, node.feature, node.threshold, node.impurity, node.n_node_samples, - node.weighted_n_node_samples, node.missing_go_to_left, - lower_bound, upper_bound) + node.weighted_n_node_samples, node.missing_go_to_left) if new_node_id == INTPTR_MAX: rc = -1 From fbfe7005d71764251bf2ab04876e9493abc993ad Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Wed, 4 Sep 2024 13:57:31 -0400 Subject: [PATCH 11/13] revert changes --- sklearn/ensemble/_gb.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index d1eeafa00dccd..599bab46afc44 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -1511,7 +1511,6 @@ def _encode_y(self, y, sample_weight): # From here on, it is additional to the HGBT case. # expose n_classes_ attribute self.n_classes_ = n_classes - if sample_weight is None: n_trim_classes = n_classes else: From 52f2b0e5d4eb7d33e627dd1dcf3f91227550da9a Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Wed, 4 Sep 2024 23:00:55 -0400 Subject: [PATCH 12/13] resolve conversations --- sklearn/ensemble/_gb.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 599bab46afc44..5505334490aee 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -1124,8 +1124,8 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): induced. :class:`~sklearn.ensemble.HistGradientBoostingClassifier` is a much faster - variant of this algorithm for intermediate datasets (`n_samples >= 10_000`) and - supports monotonicity constraints. + variant of this algorithm already for intermediate datasets (`n_samples >= 10_000`) + and supports monotonicity constraints. Read more in the :ref:`User Guide `. @@ -1728,8 +1728,8 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): loss function. :class:`~sklearn.ensemble.HistGradientBoostingRegressor` is a much faster - variant of this algorithm for intermediate datasets (`n_samples >= 10_000`) and - supports monotonic constraints. + variant of this algorithm already for intermediate datasets (`n_samples >= 10_000`) + and supports monotonic constraints. Read more in the :ref:`User Guide `. 
From 0763cd486f95aff822b40bcf9b1e5d83bfac7664 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 5 Sep 2024 11:52:19 -0400 Subject: [PATCH 13/13] resolve conversations --- sklearn/ensemble/_gb.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 5505334490aee..2ba75cbc46cb0 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -1123,9 +1123,9 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): classification is a special case where only a single regression tree is induced. - :class:`~sklearn.ensemble.HistGradientBoostingClassifier` is a much faster - variant of this algorithm already for intermediate datasets (`n_samples >= 10_000`) - and supports monotonicity constraints. + :class:`~sklearn.ensemble.HistGradientBoostingClassifier` is a much faster variant + of this algorithm for intermediate and large datasets (`n_samples >= 10_000`) and + supports monotonic constraints. Read more in the :ref:`User Guide `. @@ -1727,9 +1727,9 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): each stage a regression tree is fit on the negative gradient of the given loss function. - :class:`~sklearn.ensemble.HistGradientBoostingRegressor` is a much faster - variant of this algorithm already for intermediate datasets (`n_samples >= 10_000`) - and supports monotonic constraints. + :class:`~sklearn.ensemble.HistGradientBoostingRegressor` is a much faster variant + of this algorithm for intermediate and large datasets (`n_samples >= 10_000`) and + supports monotonic constraints. Read more in the :ref:`User Guide `.
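
Note (not part of the patch): for the classifier counterpart referenced in the final docstring, the constraints are limited to binary problems and hold over the predicted probability of the positive class. A hedged sketch of how that behaviour can be checked on predictions, mirroring the verification strategy of the test module removed earlier in this series, again assuming a release where HistGradientBoostingClassifier exposes ``monotonic_cst``:

    # Illustrative sketch only; this snippet is not part of the patch series.
    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.ensemble import HistGradientBoostingClassifier

    X, y = make_classification(
        n_samples=500, n_features=5, n_informative=5, n_redundant=0, random_state=0
    )
    monotonic_cst = np.zeros(X.shape[1], dtype=int)
    monotonic_cst[0] = 1   # P(y=1) may only increase with feature 0
    monotonic_cst[1] = -1  # P(y=1) may only decrease with feature 1

    clf = HistGradientBoostingClassifier(monotonic_cst=monotonic_cst, random_state=0)
    clf.fit(X, y)

    proba = clf.predict_proba(X)[:, 1]
    X_up = X.copy()
    X_up[:, 0] += 10.0  # push the positively constrained feature upwards
    # The positive-class probability must not drop anywhere.
    assert np.all(clf.predict_proba(X_up)[:, 1] >= proba)
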