diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 54aacb3988e81..d44caaba1b23f 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -59,6 +59,31 @@ TODO: update at the time of the release. passed to the ``fit`` method of the the estimator. :pr:`26506` by `Adrin Jalali`_. + +:mod:`sklearn.ensemble` +....................... + +- |Feature| :class:`ensemble.RandomForestClassifier`, + :class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier` + and :class:`ensemble.ExtraTreesRegressor` now support monotonic constraints, + useful when features are supposed to have a positive/negative effect on the target. + Missing values in the train data and multi-output targets are not supported. + :pr:`13649` by :user:`Samuel Ronsin `, + initiated by :user:`Patrick O'Reilly `. + + +:mod:`sklearn.tree` +................... + +- |Feature| :class:`tree.DecisionTreeClassifier`, :class:`tree.DecisionTreeRegressor`, + :class:`tree.ExtraTreeClassifier` and :class:`tree.ExtraTreeRegressor` now support + monotonic constraints, useful when features are supposed to have a positive/negative + effect on the target. Missing values in the train data and multi-output targets are + not supported. + :pr:`13649` by :user:`Samuel Ronsin `, initiated by + :user:`Patrick O'Reilly `. + + :mod:`sklearn.decomposition` ............................ diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index df8ecc974dd34..3d984c104f891 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -1273,6 +1273,25 @@ class RandomForestClassifier(ForestClassifier): .. versionadded:: 0.22 + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease + + If monotonic_cst is None, no constraints are applied. + + Monotonicity constraints are not supported for: + - multiclass classifications (i.e. when `n_classes > 2`), + - multioutput classifications (i.e. when `n_outputs_ > 1`), + - classifications trained on data with missing values. + + The constraints hold over the probability of the positive class. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.4 + Attributes ---------- estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier` @@ -1413,6 +1432,7 @@ def __init__( class_weight=None, ccp_alpha=0.0, max_samples=None, + monotonic_cst=None, ): super().__init__( estimator=DecisionTreeClassifier(), @@ -1428,6 +1448,7 @@ def __init__( "min_impurity_decrease", "random_state", "ccp_alpha", + "monotonic_cst", ), bootstrap=bootstrap, oob_score=oob_score, @@ -1447,6 +1468,7 @@ def __init__( self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease + self.monotonic_cst = monotonic_cst self.ccp_alpha = ccp_alpha @@ -1627,6 +1649,22 @@ class RandomForestRegressor(ForestRegressor): .. versionadded:: 0.22 + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonically increasing + - 0: no constraint + - -1: monotonically decreasing + + If monotonic_cst is None, no constraints are applied. + + Monotonicity constraints are not supported for: + - multioutput regressions (i.e. when `n_outputs_ > 1`), + - regressions trained on data with missing values. + + Read more in the :ref:`User Guide `. + + .. 
versionadded:: 1.4 + Attributes ---------- estimator_ : :class:`~sklearn.tree.DecisionTreeRegressor` @@ -1754,6 +1792,7 @@ def __init__( warm_start=False, ccp_alpha=0.0, max_samples=None, + monotonic_cst=None, ): super().__init__( estimator=DecisionTreeRegressor(), @@ -1769,6 +1808,7 @@ def __init__( "min_impurity_decrease", "random_state", "ccp_alpha", + "monotonic_cst", ), bootstrap=bootstrap, oob_score=oob_score, @@ -1788,6 +1828,7 @@ def __init__( self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.ccp_alpha = ccp_alpha + self.monotonic_cst = monotonic_cst class ExtraTreesClassifier(ForestClassifier): @@ -1975,6 +2016,25 @@ class ExtraTreesClassifier(ForestClassifier): .. versionadded:: 0.22 + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonically increasing + - 0: no constraint + - -1: monotonically decreasing + + If monotonic_cst is None, no constraints are applied. + + Monotonicity constraints are not supported for: + - multiclass classifications (i.e. when `n_classes > 2`), + - multioutput classifications (i.e. when `n_outputs_ > 1`), + - classifications trained on data with missing values. + + The constraints hold over the probability of the positive class. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.4 + Attributes ---------- estimator_ : :class:`~sklearn.tree.ExtraTreesClassifier` @@ -2104,6 +2164,7 @@ def __init__( class_weight=None, ccp_alpha=0.0, max_samples=None, + monotonic_cst=None, ): super().__init__( estimator=ExtraTreeClassifier(), @@ -2119,6 +2180,7 @@ def __init__( "min_impurity_decrease", "random_state", "ccp_alpha", + "monotonic_cst", ), bootstrap=bootstrap, oob_score=oob_score, @@ -2139,6 +2201,7 @@ def __init__( self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.ccp_alpha = ccp_alpha + self.monotonic_cst = monotonic_cst class ExtraTreesRegressor(ForestRegressor): @@ -2314,6 +2377,22 @@ class ExtraTreesRegressor(ForestRegressor): .. versionadded:: 0.22 + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonically increasing + - 0: no constraint + - -1: monotonically decreasing + + If monotonic_cst is None, no constraints are applied. + + Monotonicity constraints are not supported for: + - multioutput regressions (i.e. when `n_outputs_ > 1`), + - regressions trained on data with missing values. + + Read more in the :ref:`User Guide `. + + .. 
versionadded:: 1.4 + Attributes ---------- estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor` @@ -2426,6 +2505,7 @@ def __init__( warm_start=False, ccp_alpha=0.0, max_samples=None, + monotonic_cst=None, ): super().__init__( estimator=ExtraTreeRegressor(), @@ -2441,6 +2521,7 @@ def __init__( "min_impurity_decrease", "random_state", "ccp_alpha", + "monotonic_cst", ), bootstrap=bootstrap, oob_score=oob_score, @@ -2460,6 +2541,7 @@ def __init__( self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.ccp_alpha = ccp_alpha + self.monotonic_cst = monotonic_cst class RandomTreesEmbedding(TransformerMixin, BaseForest): @@ -2653,7 +2735,7 @@ class RandomTreesEmbedding(TransformerMixin, BaseForest): **BaseDecisionTree._parameter_constraints, "sparse_output": ["boolean"], } - for param in ("max_features", "ccp_alpha", "splitter"): + for param in ("max_features", "ccp_alpha", "splitter", "monotonic_cst"): _parameter_constraints.pop(param) criterion = "squared_error" diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 777e1a18d8396..c39e330d63536 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -137,6 +137,7 @@ class BaseGradientBoosting(BaseEnsemble, metaclass=ABCMeta): "tol": [Interval(Real, 0.0, None, closed="left")], } _parameter_constraints.pop("splitter") + _parameter_constraints.pop("monotonic_cst") @abstractmethod def __init__( diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 1721cd891c302..a9f367f0b21d3 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -122,6 +122,7 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): "max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None], "min_impurity_decrease": [Interval(Real, 0.0, None, closed="left")], "ccp_alpha": [Interval(Real, 0.0, None, closed="left")], + "monotonic_cst": ["array-like", None], } @abstractmethod @@ -140,6 +141,7 @@ def __init__( min_impurity_decrease, class_weight=None, ccp_alpha=0.0, + monotonic_cst=None, ): self.criterion = criterion self.splitter = splitter @@ -153,6 +155,7 @@ def __init__( self.min_impurity_decrease = min_impurity_decrease self.class_weight = class_weight self.ccp_alpha = ccp_alpha + self.monotonic_cst = monotonic_cst def get_depth(self): """Return the depth of the decision tree. @@ -180,7 +183,11 @@ def get_n_leaves(self): return self.tree_.n_leaves def _support_missing_values(self, X): - return not issparse(X) and self._get_tags()["allow_nan"] + return ( + not issparse(X) + and self._get_tags()["allow_nan"] + and self.monotonic_cst is None + ) def _compute_missing_values_in_feature_mask(self, X): """Return boolean mask denoting if there are missing values for each feature. @@ -400,6 +407,45 @@ def _fit( SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS splitter = self.splitter + if self.monotonic_cst is None: + monotonic_cst = None + else: + if self.n_outputs_ > 1: + raise ValueError( + "Monotonicity constraints are not supported with multiple outputs." + ) + # Check to correct monotonicity constraint' specification, + # by applying element-wise logical conjunction + # Note: we do not cast `np.asarray(self.monotonic_cst, dtype=np.int8)` + # straight away here so as to generate error messages for invalid + # values using the original values prior to any dtype related conversion. 
+ monotonic_cst = np.asarray(self.monotonic_cst) + if monotonic_cst.shape[0] != X.shape[1]: + raise ValueError( + "monotonic_cst has shape {} but the input data " + "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) + ) + valid_constraints = np.isin(monotonic_cst, (-1, 0, 1)) + if not np.all(valid_constraints): + unique_constaints_value = np.unique(monotonic_cst) + raise ValueError( + "monotonic_cst must be None or an array-like of -1, 0 or 1, but" + f" got {unique_constaints_value}" + ) + monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8) + if is_classifier(self): + if self.n_classes_[0] > 2: + raise ValueError( + "Monotonicity constraints are not supported with multiclass " + "classification" + ) + # Binary classification trees are built by constraining probabilities + # of the *negative class* in order to make the implementation similar + # to regression trees. + # Since self.monotonic_cst encodes constraints on probabilities of the + # *positive class*, all signs must be flipped. + monotonic_cst *= -1 + if not isinstance(self.splitter, Splitter): splitter = SPLITTERS[self.splitter]( criterion, @@ -407,6 +453,7 @@ def _fit( min_samples_leaf, min_weight_leaf, random_state, + monotonic_cst, ) if is_classifier(self): @@ -798,6 +845,25 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): .. versionadded:: 0.22 + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease + + If monotonic_cst is None, no constraints are applied. + + Monotonicity constraints are not supported for: + - multiclass classifications (i.e. when `n_classes > 2`), + - multioutput classifications (i.e. when `n_outputs_ > 1`), + - classifications trained on data with missing values. + + The constraints hold over the probability of the positive class. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.4 + Attributes ---------- classes_ : ndarray of shape (n_classes,) or list of ndarray @@ -909,6 +975,7 @@ def __init__( min_impurity_decrease=0.0, class_weight=None, ccp_alpha=0.0, + monotonic_cst=None, ): super().__init__( criterion=criterion, @@ -922,6 +989,7 @@ def __init__( class_weight=class_weight, random_state=random_state, min_impurity_decrease=min_impurity_decrease, + monotonic_cst=monotonic_cst, ccp_alpha=ccp_alpha, ) @@ -1174,6 +1242,22 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): .. versionadded:: 0.22 + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease + + If monotonic_cst is None, no constraints are applied. + + Monotonicity constraints are not supported for: + - multioutput regressions (i.e. when `n_outputs_ > 1`), + - regressions trained on data with missing values. + + Read more in the :ref:`User Guide `. + + .. 
versionadded:: 1.4 + Attributes ---------- feature_importances_ : ndarray of shape (n_features,) @@ -1272,6 +1356,7 @@ def __init__( max_leaf_nodes=None, min_impurity_decrease=0.0, ccp_alpha=0.0, + monotonic_cst=None, ): super().__init__( criterion=criterion, @@ -1285,6 +1370,7 @@ def __init__( random_state=random_state, min_impurity_decrease=min_impurity_decrease, ccp_alpha=ccp_alpha, + monotonic_cst=monotonic_cst, ) @_fit_context(prefer_skip_nested_validation=True) @@ -1499,6 +1585,25 @@ class ExtraTreeClassifier(DecisionTreeClassifier): .. versionadded:: 0.22 + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease + + If monotonic_cst is None, no constraints are applied. + + Monotonicity constraints are not supported for: + - multiclass classifications (i.e. when `n_classes > 2`), + - multioutput classifications (i.e. when `n_outputs_ > 1`), + - classifications trained on data with missing values. + + The constraints hold over the probability of the positive class. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.4 + Attributes ---------- classes_ : ndarray of shape (n_classes,) or list of ndarray @@ -1599,6 +1704,7 @@ def __init__( min_impurity_decrease=0.0, class_weight=None, ccp_alpha=0.0, + monotonic_cst=None, ): super().__init__( criterion=criterion, @@ -1613,6 +1719,7 @@ def __init__( min_impurity_decrease=min_impurity_decrease, random_state=random_state, ccp_alpha=ccp_alpha, + monotonic_cst=monotonic_cst, ) @@ -1743,6 +1850,22 @@ class ExtraTreeRegressor(DecisionTreeRegressor): .. versionadded:: 0.22 + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease + + If monotonic_cst is None, no constraints are applied. + + Monotonicity constraints are not supported for: + - multioutput regressions (i.e. when `n_outputs_ > 1`), + - regressions trained on data with missing values. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.4 + Attributes ---------- max_features_ : int @@ -1826,6 +1949,7 @@ def __init__( min_impurity_decrease=0.0, max_leaf_nodes=None, ccp_alpha=0.0, + monotonic_cst=None, ): super().__init__( criterion=criterion, @@ -1839,4 +1963,5 @@ def __init__( min_impurity_decrease=min_impurity_decrease, random_state=random_state, ccp_alpha=ccp_alpha, + monotonic_cst=monotonic_cst, ) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index a0a357a700fb4..b765d324bebb9 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -8,6 +8,7 @@ # License: BSD 3 clause # See _criterion.pyx for implementation details. 
+cimport numpy as cnp from ._tree cimport DTYPE_t # Type of X from ._tree cimport DOUBLE_t # Type of y, sample_weight @@ -68,6 +69,13 @@ cdef class Criterion: self, double* dest ) noexcept nogil + cdef void clip_node_value( + self, + double* dest, + double lower_bound, + double upper_bound + ) noexcept nogil + cdef double middle_value(self) noexcept nogil cdef double impurity_improvement( self, double impurity_parent, @@ -75,6 +83,20 @@ cdef class Criterion: double impurity_right ) noexcept nogil cdef double proxy_impurity_improvement(self) noexcept nogil + cdef bint check_monotonicity( + self, + cnp.int8_t monotonic_cst, + double lower_bound, + double upper_bound, + ) noexcept nogil + cdef inline bint _check_monotonicity( + self, + cnp.int8_t monotonic_cst, + double lower_bound, + double upper_bound, + double sum_left, + double sum_right, + ) noexcept nogil cdef class ClassificationCriterion(Criterion): """Abstract criterion for classification.""" diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 91c347735c5e0..79f6346be239d 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -157,6 +157,16 @@ cdef class Criterion: """ pass + cdef void clip_node_value(self, double* dest, double lower_bound, double upper_bound) noexcept nogil: + pass + + cdef double middle_value(self) noexcept nogil: + """Compute the middle value of a split for monotonicity constraints + + This method is implemented in ClassificationCriterion and RegressionCriterion. + """ + pass + cdef double proxy_impurity_improvement(self) noexcept nogil: """Compute a proxy of the impurity reduction. @@ -211,6 +221,36 @@ cdef class Criterion: - (self.weighted_n_left / self.weighted_n_node_samples * impurity_left))) + cdef bint check_monotonicity( + self, + cnp.int8_t monotonic_cst, + double lower_bound, + double upper_bound, + ) noexcept nogil: + pass + + cdef inline bint _check_monotonicity( + self, + cnp.int8_t monotonic_cst, + double lower_bound, + double upper_bound, + double value_left, + double value_right, + ) nogil: + cdef: + bint check_lower_bound = ( + (value_left >= lower_bound) & + (value_right >= lower_bound) + ) + bint check_upper_bound = ( + (value_left <= upper_bound) & + (value_right <= upper_bound) + ) + bint check_monotonic_cst = ( + (value_left - value_right) * monotonic_cst <= 0 + ) + return check_lower_bound & check_upper_bound & check_monotonic_cst + cdef void init_sum_missing(self): """Init sum_missing to hold sums for missing values.""" @@ -534,6 +574,47 @@ cdef class ClassificationCriterion(Criterion): memcpy(dest, &self.sum_total[k, 0], self.n_classes[k] * sizeof(double)) dest += self.max_n_classes + cdef void clip_node_value(self, double * dest, double lower_bound, double upper_bound) noexcept nogil: + """Clip the value in dest between lower_bound and upper_bound for monotonic constraints. + + Note that monotonicity constraints are only supported for: + - single-output trees and + - binary classifications. + """ + if dest[0] < lower_bound: + dest[0] = lower_bound + elif dest[0] > upper_bound: + dest[0] = upper_bound + + # Class proportions for binary classification must sum to 1. + dest[1] = 1 - dest[0] + + cdef inline double middle_value(self) noexcept nogil: + """Compute the middle value of a split for monotonicity constraints as the simple average + of the left and right children values. + + Note that monotonicity constraints are only supported for: + - single-output trees and + - binary classifications. 
+ """ + return ( + (self.sum_left[0, 0] / (2 * self.weighted_n_left)) + + (self.sum_right[0, 0] / (2 * self.weighted_n_right)) + ) + + cdef inline bint check_monotonicity( + self, + cnp.int8_t monotonic_cst, + double lower_bound, + double upper_bound, + ) noexcept nogil: + """Check monotonicity constraint is satisfied at the current classification split""" + cdef: + double value_left = self.sum_left[0][0] / self.weighted_n_left + double value_right = self.sum_right[0][0] / self.weighted_n_right + + return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, value_left, value_right) + cdef class Entropy(ClassificationCriterion): r"""Cross Entropy impurity criterion. @@ -959,6 +1040,37 @@ cdef class RegressionCriterion(Criterion): for k in range(self.n_outputs): dest[k] = self.sum_total[k] / self.weighted_n_node_samples + cdef inline void clip_node_value(self, double* dest, double lower_bound, double upper_bound) noexcept nogil: + """Clip the value in dest between lower_bound and upper_bound for monotonic constraints.""" + if dest[0] < lower_bound: + dest[0] = lower_bound + elif dest[0] > upper_bound: + dest[0] = upper_bound + + cdef double middle_value(self) noexcept nogil: + """Compute the middle value of a split for monotonicity constraints as the simple average + of the left and right children values. + + Monotonicity constraints are only supported for single-output trees we can safely assume + n_outputs == 1. + """ + return ( + (self.sum_left[0] / (2 * self.weighted_n_left)) + + (self.sum_right[0] / (2 * self.weighted_n_right)) + ) + + cdef bint check_monotonicity( + self, + cnp.int8_t monotonic_cst, + double lower_bound, + double upper_bound, + ) noexcept nogil: + """Check monotonicity constraint is satisfied at the current regression split""" + cdef: + double value_left = self.sum_left[0] / self.weighted_n_left + double value_right = self.sum_right[0] / self.weighted_n_right + + return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, value_left, value_right) cdef class MSE(RegressionCriterion): """Mean squared error impurity criterion. @@ -1289,6 +1401,31 @@ cdef class MAE(RegressionCriterion): for k in range(self.n_outputs): dest[k] = self.node_medians[k] + cdef inline double middle_value(self) noexcept nogil: + """Compute the middle value of a split for monotonicity constraints as the simple average + of the left and right children values. + + Monotonicity constraints are only supported for single-output trees we can safely assume + n_outputs == 1. + """ + return ( + ( self.left_child_ptr[0]).get_median() + + ( self.right_child_ptr[0]).get_median() + ) / 2 + + cdef inline bint check_monotonicity( + self, + cnp.int8_t monotonic_cst, + double lower_bound, + double upper_bound, + ) noexcept nogil: + """Check monotonicity constraint is satisfied at the current regression split""" + cdef: + double value_left = ( self.left_child_ptr[0]).get_median() + double value_right = ( self.right_child_ptr[0]).get_median() + + return self._check_monotonicity(monotonic_cst, lower_bound, upper_bound, value_left, value_right) + cdef double node_impurity(self) noexcept nogil: """Evaluate the impurity of the current node. diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index acc67a7315add..2547e14b324df 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -8,6 +8,7 @@ # License: BSD 3 clause # See _splitter.pyx for details. 
+cimport numpy as cnp from ._criterion cimport Criterion @@ -27,6 +28,8 @@ cdef struct SplitRecord: double improvement # Impurity improvement given parent node. double impurity_left # Impurity of the left split. double impurity_right # Impurity of the right split. + double lower_bound # Lower bound on value of both children for monotonicity + double upper_bound # Upper bound on value of both children for monotonicity unsigned char missing_go_to_left # Controls if missing values go to the left node. SIZE_t n_missing # Number of missing values for the feature being split on @@ -57,6 +60,13 @@ cdef class Splitter: cdef SIZE_t end # End position for the current node cdef const DOUBLE_t[:, ::1] y + # Monotonicity constraints for each feature. + # The encoding is as follows: + # -1: monotonic decrease + # 0: no constraint + # +1: monotonic increase + cdef const cnp.int8_t[:] monotonic_cst + cdef bint with_monotonic_cst cdef const DOUBLE_t[:] sample_weight # The samples vector `samples` is maintained by the Splitter object such @@ -95,9 +105,13 @@ cdef class Splitter: self, double impurity, # Impurity of the node SplitRecord* split, - SIZE_t* n_constant_features + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound, ) except -1 nogil cdef void node_value(self, double* dest) noexcept nogil + cdef void clip_node_value(self, double* dest, double lower_bound, double upper_bound) noexcept nogil + cdef double node_impurity(self) noexcept nogil diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 7e60f0023d2a2..edbfff13cd941 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -11,6 +11,8 @@ # # License: BSD 3 clause +cimport numpy as cnp + from ._criterion cimport Criterion from libc.stdlib cimport qsort @@ -53,9 +55,15 @@ cdef class Splitter: sparse and dense data, one split at a time. """ - def __cinit__(self, Criterion criterion, SIZE_t max_features, - SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state): + def __cinit__( + self, + Criterion criterion, + SIZE_t max_features, + SIZE_t min_samples_leaf, + double min_weight_leaf, + object random_state, + const cnp.int8_t[:] monotonic_cst, + ): """ Parameters ---------- @@ -77,6 +85,10 @@ cdef class Splitter: random_state : object The user inputted random state to be used for pseudo-randomness + + monotonic_cst : const cnp.int8_t[:] + Monotonicity constraints + """ self.criterion = criterion @@ -88,6 +100,8 @@ cdef class Splitter: self.min_samples_leaf = min_samples_leaf self.min_weight_leaf = min_weight_leaf self.random_state = random_state + self.monotonic_cst = monotonic_cst + self.with_monotonic_cst = monotonic_cst is not None def __getstate__(self): return {} @@ -100,7 +114,8 @@ cdef class Splitter: self.max_features, self.min_samples_leaf, self.min_weight_leaf, - self.random_state), self.__getstate__()) + self.random_state, + self.monotonic_cst), self.__getstate__()) cdef int init( self, @@ -208,8 +223,15 @@ cdef class Splitter: weighted_n_node_samples[0] = self.criterion.weighted_n_node_samples return 0 - cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) except -1 nogil: + cdef int node_split( + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound, + ) nogil except -1: + """Find the best split on node samples[start:end]. This is a placeholder method. 
The majority of computation will be done @@ -225,6 +247,11 @@ cdef class Splitter: self.criterion.node_value(dest) + cdef inline void clip_node_value(self, double* dest, double lower_bound, double upper_bound) noexcept nogil: + """Clip the value in dest between lower_bound and upper_bound for monotonic constraints.""" + + self.criterion.clip_node_value(dest, lower_bound, upper_bound) + cdef double node_impurity(self) noexcept nogil: """Return the impurity of the current node.""" @@ -264,6 +291,10 @@ cdef inline int node_split_best( double impurity, SplitRecord* split, SIZE_t* n_constant_features, + bint with_monotonic_cst, + const cnp.int8_t[:] monotonic_cst, + double lower_bound, + double upper_bound, ) except -1 nogil: """Find the best split on node samples[start:end] @@ -416,6 +447,18 @@ cdef inline int node_split_best( current_split.pos = p criterion.update(current_split.pos) + # Reject if monotonicity constraints are not satisfied + if ( + with_monotonic_cst and + monotonic_cst[current_split.feature] != 0 and + not criterion.check_monotonicity( + monotonic_cst[current_split.feature], + lower_bound, + upper_bound, + ) + ): + continue + # Reject if min_weight_leaf is not satisfied if ((criterion.weighted_n_left < min_weight_leaf) or (criterion.weighted_n_right < min_weight_leaf)): @@ -628,7 +671,11 @@ cdef inline int node_split_random( Criterion criterion, double impurity, SplitRecord* split, - SIZE_t* n_constant_features + SIZE_t* n_constant_features, + bint with_monotonic_cst, + const cnp.int8_t[:] monotonic_cst, + double lower_bound, + double upper_bound, ) except -1 nogil: """Find the best random split on node samples[start:end] @@ -756,6 +803,18 @@ cdef inline int node_split_random( (criterion.weighted_n_right < min_weight_leaf)): continue + # Reject if monotonicity constraints are not satisfied + if ( + with_monotonic_cst and + monotonic_cst[current_split.feature] != 0 and + not criterion.check_monotonicity( + monotonic_cst[current_split.feature], + lower_bound, + upper_bound, + ) + ): + continue + current_proxy_improvement = criterion.proxy_impurity_improvement() if current_proxy_improvement > best_proxy_improvement: @@ -1441,8 +1500,14 @@ cdef class BestSplitter(Splitter): X, self.samples, self.feature_values, missing_values_in_feature_mask ) - cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) except -1 nogil: + cdef int node_split( + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound + ) except -1 nogil: return node_split_best( self, self.partitioner, @@ -1450,6 +1515,10 @@ cdef class BestSplitter(Splitter): impurity, split, n_constant_features, + self.with_monotonic_cst, + self.monotonic_cst, + lower_bound, + upper_bound ) cdef class BestSparseSplitter(Splitter): @@ -1467,8 +1536,14 @@ cdef class BestSparseSplitter(Splitter): X, self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask ) - cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) except -1 nogil: + cdef int node_split( + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound + ) except -1 nogil: return node_split_best( self, self.partitioner, @@ -1476,6 +1551,10 @@ cdef class BestSparseSplitter(Splitter): impurity, split, n_constant_features, + self.with_monotonic_cst, + self.monotonic_cst, + lower_bound, + upper_bound ) cdef class RandomSplitter(Splitter): @@ -1493,8 
+1572,14 @@ cdef class RandomSplitter(Splitter): X, self.samples, self.feature_values, missing_values_in_feature_mask ) - cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) except -1 nogil: + cdef int node_split( + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound + ) except -1 nogil: return node_split_random( self, self.partitioner, @@ -1502,6 +1587,10 @@ cdef class RandomSplitter(Splitter): impurity, split, n_constant_features, + self.with_monotonic_cst, + self.monotonic_cst, + lower_bound, + upper_bound ) cdef class RandomSparseSplitter(Splitter): @@ -1518,9 +1607,14 @@ cdef class RandomSparseSplitter(Splitter): self.partitioner = SparsePartitioner( X, self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask ) - - cdef int node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) except -1 nogil: + cdef int node_split( + self, + double impurity, + SplitRecord* split, + SIZE_t* n_constant_features, + double lower_bound, + double upper_bound + ) except -1 nogil: return node_split_random( self, self.partitioner, @@ -1528,4 +1622,8 @@ cdef class RandomSparseSplitter(Splitter): impurity, split, n_constant_features, + self.with_monotonic_cst, + self.monotonic_cst, + lower_bound, + upper_bound ) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index e7a0ab2f2966d..e0aeec26bfef4 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -148,6 +148,8 @@ cdef struct StackRecord: bint is_left double impurity SIZE_t n_constant_features + double lower_bound + double upper_bound cdef class DepthFirstTreeBuilder(TreeBuilder): """Build a decision tree in depth-first fashion.""" @@ -207,6 +209,9 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef SIZE_t node_id cdef double impurity = INFINITY + cdef double lower_bound + cdef double upper_bound + cdef double middle_value cdef SIZE_t n_constant_features cdef bint is_leaf cdef bint first = 1 @@ -225,7 +230,10 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): "parent": _TREE_UNDEFINED, "is_left": 0, "impurity": INFINITY, - "n_constant_features": 0}) + "n_constant_features": 0, + "lower_bound": -INFINITY, + "upper_bound": INFINITY, + }) while not builder_stack.empty(): stack_record = builder_stack.top() @@ -238,6 +246,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): is_left = stack_record.is_left impurity = stack_record.impurity n_constant_features = stack_record.n_constant_features + lower_bound = stack_record.lower_bound + upper_bound = stack_record.upper_bound n_node_samples = end - start splitter.node_reset(start, end, &weighted_n_node_samples) @@ -255,7 +265,13 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): is_leaf = is_leaf or impurity <= EPSILON if not is_leaf: - splitter.node_split(impurity, &split, &n_constant_features) + splitter.node_split( + impurity, + &split, + &n_constant_features, + lower_bound, + upper_bound + ) # If EPSILON=0 in the below comparison, float precision # issues stop splitting, producing trees that are # dissimilar to v0.18 @@ -275,8 +291,42 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # Store value for all nodes, to facilitate tree/model # inspection and interpretation splitter.node_value(tree.value + node_id * tree.value_stride) + if splitter.with_monotonic_cst: + splitter.clip_node_value(tree.value + node_id * tree.value_stride, lower_bound, upper_bound) if not is_leaf: + if ( + not splitter.with_monotonic_cst or + 
splitter.monotonic_cst[split.feature] == 0 + ): + # Split on a feature with no monotonicity constraint + + # Current bounds must always be propagated to both children. + # If a monotonic constraint is active, bounds are used in + # node value clipping. + left_child_min = right_child_min = lower_bound + left_child_max = right_child_max = upper_bound + elif splitter.monotonic_cst[split.feature] == 1: + # Split on a feature with monotonic increase constraint + left_child_min = lower_bound + right_child_max = upper_bound + + # Lower bound for right child and upper bound for left child + # are set to the same value. + middle_value = splitter.criterion.middle_value() + right_child_min = middle_value + left_child_max = middle_value + else: # i.e. splitter.monotonic_cst[split.feature] == -1 + # Split on a feature with monotonic decrease constraint + right_child_min = lower_bound + left_child_max = upper_bound + + # Lower bound for left child and upper bound for right child + # are set to the same value. + middle_value = splitter.criterion.middle_value() + left_child_min = middle_value + right_child_max = middle_value + # Push right child on stack builder_stack.push({ "start": split.pos, @@ -285,7 +335,10 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): "parent": node_id, "is_left": 0, "impurity": split.impurity_right, - "n_constant_features": n_constant_features}) + "n_constant_features": n_constant_features, + "lower_bound": right_child_min, + "upper_bound": right_child_max, + }) # Push left child on stack builder_stack.push({ @@ -295,7 +348,10 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): "parent": node_id, "is_left": 1, "impurity": split.impurity_left, - "n_constant_features": n_constant_features}) + "n_constant_features": n_constant_features, + "lower_bound": left_child_min, + "upper_bound": left_child_max, + }) if depth > max_depth_seen: max_depth_seen = depth @@ -324,6 +380,9 @@ cdef struct FrontierRecord: double impurity_left double impurity_right double improvement + double lower_bound + double upper_bound + double middle_value cdef inline bool _compare_records( const FrontierRecord& left, @@ -384,6 +443,10 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef FrontierRecord record cdef FrontierRecord split_node_left cdef FrontierRecord split_node_right + cdef double left_child_min + cdef double left_child_max + cdef double right_child_min + cdef double right_child_max cdef SIZE_t n_node_samples = splitter.n_samples cdef SIZE_t max_split_nodes = max_leaf_nodes - 1 @@ -398,9 +461,20 @@ cdef class BestFirstTreeBuilder(TreeBuilder): with nogil: # add root to frontier - rc = self._add_split_node(splitter, tree, 0, n_node_samples, - INFINITY, IS_FIRST, IS_LEFT, NULL, 0, - &split_node_left) + rc = self._add_split_node( + splitter=splitter, + tree=tree, + start=0, + end=n_node_samples, + impurity=INFINITY, + is_first=IS_FIRST, + is_left=IS_LEFT, + parent=NULL, + depth=0, + lower_bound=-INFINITY, + upper_bound=INFINITY, + res=&split_node_left, + ) if rc >= 0: _add_to_frontier(split_node_left, frontier) @@ -422,16 +496,54 @@ cdef class BestFirstTreeBuilder(TreeBuilder): else: # Node is expandable + if ( + not splitter.with_monotonic_cst or + splitter.monotonic_cst[node.feature] == 0 + ): + # Split on a feature with no monotonicity constraint + + # Current bounds must always be propagated to both children. + # If a monotonic constraint is active, bounds are used in + # node value clipping. 
+ left_child_min = right_child_min = record.lower_bound + left_child_max = right_child_max = record.upper_bound + elif splitter.monotonic_cst[node.feature] == 1: + # Split on a feature with monotonic increase constraint + left_child_min = record.lower_bound + right_child_max = record.upper_bound + + # Lower bound for right child and upper bound for left child + # are set to the same value. + right_child_min = record.middle_value + left_child_max = record.middle_value + else: # i.e. splitter.monotonic_cst[split.feature] == -1 + # Split on a feature with monotonic decrease constraint + right_child_min = record.lower_bound + left_child_max = record.upper_bound + + # Lower bound for left child and upper bound for right child + # are set to the same value. + left_child_min = record.middle_value + right_child_max = record.middle_value + # Decrement number of split nodes available max_split_nodes -= 1 # Compute left split node - rc = self._add_split_node(splitter, tree, - record.start, record.pos, - record.impurity_left, - IS_NOT_FIRST, IS_LEFT, node, - record.depth + 1, - &split_node_left) + rc = self._add_split_node( + splitter=splitter, + tree=tree, + start=record.start, + end=record.pos, + impurity=record.impurity_left, + is_first=IS_NOT_FIRST, + is_left=IS_LEFT, + parent=node, + depth=record.depth + 1, + lower_bound=left_child_min, + upper_bound=left_child_max, + res=&split_node_left, + ) if rc == -1: break @@ -439,12 +551,20 @@ cdef class BestFirstTreeBuilder(TreeBuilder): node = &tree.nodes[record.node_id] # Compute right split node - rc = self._add_split_node(splitter, tree, record.pos, - record.end, - record.impurity_right, - IS_NOT_FIRST, IS_NOT_LEFT, node, - record.depth + 1, - &split_node_right) + rc = self._add_split_node( + splitter=splitter, + tree=tree, + start=record.pos, + end=record.end, + impurity=record.impurity_right, + is_first=IS_NOT_FIRST, + is_left=IS_NOT_LEFT, + parent=node, + depth=record.depth + 1, + lower_bound=right_child_min, + upper_bound=right_child_max, + res=&split_node_right, + ) if rc == -1: break @@ -464,11 +584,21 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if rc == -1: raise MemoryError() - cdef inline int _add_split_node(self, Splitter splitter, Tree tree, - SIZE_t start, SIZE_t end, double impurity, - bint is_first, bint is_left, Node* parent, - SIZE_t depth, - FrontierRecord* res) except -1 nogil: + cdef inline int _add_split_node( + self, + Splitter splitter, + Tree tree, + SIZE_t start, + SIZE_t end, + double impurity, + bint is_first, + bint is_left, + Node* parent, + SIZE_t depth, + double lower_bound, + double upper_bound, + FrontierRecord* res + ) nogil except -1: """Adds node w/ partition ``[start, end)`` to the frontier. """ cdef SplitRecord split cdef SIZE_t node_id @@ -492,7 +622,13 @@ cdef class BestFirstTreeBuilder(TreeBuilder): ) if not is_leaf: - splitter.node_split(impurity, &split, &n_constant_features) + splitter.node_split( + impurity, + &split, + &n_constant_features, + lower_bound, + upper_bound + ) # If EPSILON=0 in the below comparison, float precision issues stop # splitting early, producing trees that are dissimilar to v0.18 is_leaf = (is_leaf or split.pos >= end or @@ -510,12 +646,17 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # compute values also for split nodes (might become leafs later). 
splitter.node_value(tree.value + node_id * tree.value_stride) + if splitter.with_monotonic_cst: + splitter.clip_node_value(tree.value + node_id * tree.value_stride, lower_bound, upper_bound) res.node_id = node_id res.start = start res.end = end res.depth = depth res.impurity = impurity + res.lower_bound = lower_bound + res.upper_bound = upper_bound + res.middle_value = splitter.criterion.middle_value() if not is_leaf: # is split node diff --git a/sklearn/tree/tests/test_monotonic_tree.py b/sklearn/tree/tests/test_monotonic_tree.py new file mode 100644 index 0000000000000..462ac7305d7c2 --- /dev/null +++ b/sklearn/tree/tests/test_monotonic_tree.py @@ -0,0 +1,491 @@ +import numpy as np +import pytest +import scipy.sparse + +from sklearn.datasets import make_classification, make_regression +from sklearn.ensemble import ( + ExtraTreesClassifier, + ExtraTreesRegressor, + RandomForestClassifier, + RandomForestRegressor, +) +from sklearn.tree import ( + DecisionTreeClassifier, + DecisionTreeRegressor, + ExtraTreeClassifier, + ExtraTreeRegressor, +) + +TREE_CLASSIFIER_CLASSES = [DecisionTreeClassifier, ExtraTreeClassifier] +TREE_REGRESSOR_CLASSES = [DecisionTreeRegressor, ExtraTreeRegressor] +TREE_BASED_CLASSIFIER_CLASSES = TREE_CLASSIFIER_CLASSES + [ + RandomForestClassifier, + ExtraTreesClassifier, +] +TREE_BASED_REGRESSOR_CLASSES = TREE_REGRESSOR_CLASSES + [ + RandomForestRegressor, + ExtraTreesRegressor, +] + + +@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES) +@pytest.mark.parametrize("depth_first_builder", (True, False)) +@pytest.mark.parametrize("sparse_splitter", (True, False)) +def test_monotonic_constraints_classifications( + TreeClassifier, depth_first_builder, sparse_splitter, global_random_seed +): + n_samples = 1000 + n_samples_train = 900 + X, y = make_classification( + n_samples=n_samples, + n_classes=2, + n_features=5, + n_informative=5, + n_redundant=0, + random_state=global_random_seed, + ) + X_train, y_train = X[:n_samples_train], y[:n_samples_train] + X_test, _ = X[n_samples_train:], y[n_samples_train:] + + X_test_0incr, X_test_0decr = np.copy(X_test), np.copy(X_test) + X_test_1incr, X_test_1decr = np.copy(X_test), np.copy(X_test) + X_test_0incr[:, 0] += 10 + X_test_0decr[:, 0] -= 10 + X_test_1incr[:, 1] += 10 + X_test_1decr[:, 1] -= 10 + monotonic_cst = np.zeros(X.shape[1]) + monotonic_cst[0] = 1 + monotonic_cst[1] = -1 + + if depth_first_builder: + est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst) + else: + est = TreeClassifier( + max_depth=None, + monotonic_cst=monotonic_cst, + max_leaf_nodes=n_samples_train, + ) + if hasattr(est, "random_state"): + est.set_params(**{"random_state": global_random_seed}) + if hasattr(est, "n_estimators"): + est.set_params(**{"n_estimators": 5}) + if sparse_splitter: + X_train = scipy.sparse.csc_matrix(X_train) + est.fit(X_train, y_train) + y = est.predict_proba(X_test)[:, 1] + + # Monotonic increase constraint, it applies to the positive class + assert np.all(est.predict_proba(X_test_0incr)[:, 1] >= y) + assert np.all(est.predict_proba(X_test_0decr)[:, 1] <= y) + + # Monotonic decrease constraint, it applies to the positive class + assert np.all(est.predict_proba(X_test_1incr)[:, 1] <= y) + assert np.all(est.predict_proba(X_test_1decr)[:, 1] >= y) + + +@pytest.mark.parametrize("TreeRegressor", TREE_BASED_REGRESSOR_CLASSES) +@pytest.mark.parametrize("depth_first_builder", (True, False)) +@pytest.mark.parametrize("sparse_splitter", (True, False)) +@pytest.mark.parametrize("criterion", ("absolute_error", 
"squared_error")) +def test_monotonic_constraints_regressions( + TreeRegressor, depth_first_builder, sparse_splitter, criterion, global_random_seed +): + n_samples = 1000 + n_samples_train = 900 + # Build a regression task using 5 informative features + X, y = make_regression( + n_samples=n_samples, + n_features=5, + n_informative=5, + random_state=global_random_seed, + ) + train = np.arange(n_samples_train) + test = np.arange(n_samples_train, n_samples) + X_train = X[train] + y_train = y[train] + X_test = np.copy(X[test]) + X_test_incr = np.copy(X_test) + X_test_decr = np.copy(X_test) + X_test_incr[:, 0] += 10 + X_test_decr[:, 1] += 10 + monotonic_cst = np.zeros(X.shape[1]) + monotonic_cst[0] = 1 + monotonic_cst[1] = -1 + + if depth_first_builder: + est = TreeRegressor( + max_depth=None, + monotonic_cst=monotonic_cst, + criterion=criterion, + ) + else: + est = TreeRegressor( + max_depth=8, + monotonic_cst=monotonic_cst, + criterion=criterion, + max_leaf_nodes=n_samples_train, + ) + if hasattr(est, "random_state"): + est.set_params(random_state=global_random_seed) + if hasattr(est, "n_estimators"): + est.set_params(**{"n_estimators": 5}) + if sparse_splitter: + X_train = scipy.sparse.csc_matrix(X_train) + est.fit(X_train, y_train) + y = est.predict(X_test) + # Monotonic increase constraint + y_incr = est.predict(X_test_incr) + # y_incr should always be greater than y + assert np.all(y_incr >= y) + + # Monotonic decrease constraint + y_decr = est.predict(X_test_decr) + # y_decr should always be lower than y + assert np.all(y_decr <= y) + + +@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES) +def test_multiclass_raises(TreeClassifier): + X, y = make_classification( + n_samples=100, n_features=5, n_classes=3, n_informative=3, random_state=0 + ) + y[0] = 0 + monotonic_cst = np.zeros(X.shape[1]) + monotonic_cst[0] = -1 + monotonic_cst[1] = 1 + est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst, random_state=0) + + msg = "Monotonicity constraints are not supported with multiclass classification" + with pytest.raises(ValueError, match=msg): + est.fit(X, y) + + +@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES) +def test_multiple_output_raises(TreeClassifier): + X = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]] + y = [[1, 0, 1, 0, 1], [1, 0, 1, 0, 1]] + + est = TreeClassifier( + max_depth=None, monotonic_cst=np.array([-1, 1]), random_state=0 + ) + msg = "Monotonicity constraints are not supported with multiple output" + with pytest.raises(ValueError, match=msg): + est.fit(X, y) + + +@pytest.mark.parametrize( + "DecisionTreeEstimator", [DecisionTreeClassifier, DecisionTreeRegressor] +) +def test_missing_values_raises(DecisionTreeEstimator): + X, y = make_classification( + n_samples=100, n_features=5, n_classes=2, n_informative=3, random_state=0 + ) + X[0, 0] = np.nan + monotonic_cst = np.zeros(X.shape[1]) + monotonic_cst[0] = 1 + est = DecisionTreeEstimator( + max_depth=None, monotonic_cst=monotonic_cst, random_state=0 + ) + + msg = "Input X contains NaN" + with pytest.raises(ValueError, match=msg): + est.fit(X, y) + + +@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES) +def test_bad_monotonic_cst_raises(TreeClassifier): + X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] + y = [1, 0, 1, 0, 1] + + msg = "monotonic_cst has shape 3 but the input data X has 2 features." 
+ est = TreeClassifier( + max_depth=None, monotonic_cst=np.array([-1, 1, 0]), random_state=0 + ) + with pytest.raises(ValueError, match=msg): + est.fit(X, y) + + msg = "monotonic_cst must be None or an array-like of -1, 0 or 1." + est = TreeClassifier( + max_depth=None, monotonic_cst=np.array([-2, 2]), random_state=0 + ) + with pytest.raises(ValueError, match=msg): + est.fit(X, y) + + est = TreeClassifier( + max_depth=None, monotonic_cst=np.array([-1, 0.8]), random_state=0 + ) + with pytest.raises(ValueError, match=msg + "(.*)0.8]"): + est.fit(X, y) + + +def assert_1d_reg_tree_children_monotonic_bounded(tree_, monotonic_sign): + values = tree_.value + for i in range(tree_.node_count): + if tree_.children_left[i] > i and tree_.children_right[i] > i: + # Check monotonicity on children + i_left = tree_.children_left[i] + i_right = tree_.children_right[i] + if monotonic_sign == 1: + assert values[i_left] <= values[i_right] + elif monotonic_sign == -1: + assert values[i_left] >= values[i_right] + val_middle = (values[i_left] + values[i_right]) / 2 + # Check bounds on grand-children, filtering out leaf nodes + if tree_.feature[i_left] >= 0: + i_left_right = tree_.children_right[i_left] + if monotonic_sign == 1: + assert values[i_left_right] <= val_middle + elif monotonic_sign == -1: + assert values[i_left_right] >= val_middle + if tree_.feature[i_right] >= 0: + i_right_left = tree_.children_left[i_right] + if monotonic_sign == 1: + assert val_middle <= values[i_right_left] + elif monotonic_sign == -1: + assert val_middle >= values[i_right_left] + + +def test_assert_1d_reg_tree_children_monotonic_bounded(): + X = np.linspace(-1, 1, 7).reshape(-1, 1) + y = np.sin(2 * np.pi * X.ravel()) + + reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, y) + + with pytest.raises(AssertionError): + assert_1d_reg_tree_children_monotonic_bounded(reg.tree_, 1) + + with pytest.raises(AssertionError): + assert_1d_reg_tree_children_monotonic_bounded(reg.tree_, -1) + + +def assert_1d_reg_monotonic(clf, monotonic_sign, min_x, max_x, n_steps): + X_grid = np.linspace(min_x, max_x, n_steps).reshape(-1, 1) + y_pred_grid = clf.predict(X_grid) + if monotonic_sign == 1: + assert (np.diff(y_pred_grid) >= 0.0).all() + elif monotonic_sign == -1: + assert (np.diff(y_pred_grid) <= 0.0).all() + + +@pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) +def test_1d_opposite_monotonicity_cst_data(TreeRegressor): + # Check that positive monotonic data with negative monotonic constraint + # yield constant predictions, equal to the average of target values + X = np.linspace(-2, 2, 10).reshape(-1, 1) + y = X.ravel() + clf = TreeRegressor(monotonic_cst=[-1]) + clf.fit(X, y) + assert clf.tree_.node_count == 1 + assert clf.tree_.value[0] == 0.0 + + # Swap monotonicity + clf = TreeRegressor(monotonic_cst=[1]) + clf.fit(X, -y) + assert clf.tree_.node_count == 1 + assert clf.tree_.value[0] == 0.0 + + +@pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) +@pytest.mark.parametrize("monotonic_sign", (-1, 1)) +@pytest.mark.parametrize("depth_first_builder", (True, False)) +@pytest.mark.parametrize("criterion", ("absolute_error", "squared_error")) +def test_1d_tree_nodes_values( + TreeRegressor, monotonic_sign, depth_first_builder, criterion, global_random_seed +): + # Adaptation from test_nodes_values in test_monotonic_constraints.py + # in sklearn.ensemble._hist_gradient_boosting + # Build a single tree with only one feature, and make sure the node + # values respect the monotonicity constraints. 
+ + # Considering the following tree with a monotonic +1 constraint, we + # should have: + # + # root + # / \ + # a b + # / \ / \ + # c d e f + # + # a <= root <= b + # c <= d <= (a + b) / 2 <= e <= f + + rng = np.random.RandomState(global_random_seed) + n_samples = 1000 + n_features = 1 + X = rng.rand(n_samples, n_features) + y = rng.rand(n_samples) + + if depth_first_builder: + # No max_leaf_nodes, default depth first tree builder + clf = TreeRegressor( + monotonic_cst=[monotonic_sign], + criterion=criterion, + random_state=global_random_seed, + ) + else: + # max_leaf_nodes triggers best first tree builder + clf = TreeRegressor( + monotonic_cst=[monotonic_sign], + max_leaf_nodes=n_samples, + criterion=criterion, + random_state=global_random_seed, + ) + clf.fit(X, y) + + assert_1d_reg_tree_children_monotonic_bounded(clf.tree_, monotonic_sign) + assert_1d_reg_monotonic(clf, monotonic_sign, np.min(X), np.max(X), 100) + + +def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst): + upper_bound = np.full(tree_.node_count, np.inf) + lower_bound = np.full(tree_.node_count, -np.inf) + for i in range(tree_.node_count): + feature = tree_.feature[i] + node_value = tree_.value[i][0][0] # unpack value from nx1x1 array + # While building the tree, the computed middle value is slightly + # different from the average of the siblings values, because + # sum_right / weighted_n_right + # is slightly different from the value of the right sibling. + # This can cause a discrepancy up to numerical noise when clipping, + # which is resolved by comparing with some loss of precision. + assert np.float32(node_value) <= np.float32(upper_bound[i]) + assert np.float32(node_value) >= np.float32(lower_bound[i]) + + if feature < 0: + # Leaf: nothing to do + continue + + # Split node: check and update bounds for the children. + i_left = tree_.children_left[i] + i_right = tree_.children_right[i] + # unpack value from nx1x1 array + middle_value = (tree_.value[i_left][0][0] + tree_.value[i_right][0][0]) / 2 + + if monotonic_cst[feature] == 0: + # Feature without monotonicity constraint: propagate bounds + # down the tree to both children. + # Otherwise, with 2 features and a monotonic increase constraint + # (encoded by +1) on feature 0, the following tree can be accepted, + # although it does not respect the monotonic increase constraint: + # + # X[0] <= 0 + # value = 100 + # / \ + # X[0] <= -1 X[1] <= 0 + # value = 50 value = 150 + # / \ / \ + # leaf leaf leaf leaf + # value = 25 value = 75 value = 50 value = 250 + + lower_bound[i_left] = lower_bound[i] + upper_bound[i_left] = upper_bound[i] + lower_bound[i_right] = lower_bound[i] + upper_bound[i_right] = upper_bound[i] + + elif monotonic_cst[feature] == 1: + # Feature with constraint: check monotonicity + assert tree_.value[i_left] <= tree_.value[i_right] + + # Propagate bounds down the tree to both children. + lower_bound[i_left] = lower_bound[i] + upper_bound[i_left] = middle_value + lower_bound[i_right] = middle_value + upper_bound[i_right] = upper_bound[i] + + elif monotonic_cst[feature] == -1: + # Feature with constraint: check monotonicity + assert tree_.value[i_left] >= tree_.value[i_right] + + # Update and propagate bounds down the tree to both children. 
+ lower_bound[i_left] = middle_value + upper_bound[i_left] = upper_bound[i] + lower_bound[i_right] = lower_bound[i] + upper_bound[i_right] = middle_value + + else: # pragma: no cover + raise ValueError(f"monotonic_cst[{feature}]={monotonic_cst[feature]}") + + +def test_assert_nd_reg_tree_children_monotonic_bounded(): + # Check that assert_nd_reg_tree_children_monotonic_bounded can detect + # non-monotonic tree predictions. + X = np.linspace(0, 2 * np.pi, 30).reshape(-1, 1) + y = np.sin(X).ravel() + reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, y) + + with pytest.raises(AssertionError): + assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [1]) + + with pytest.raises(AssertionError): + assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [-1]) + + assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [0]) + + # Check that assert_nd_reg_tree_children_monotonic_bounded raises + # when the data (and therefore the model) is naturally monotonic in the + # opposite direction. + X = np.linspace(-5, 5, 5).reshape(-1, 1) + y = X.ravel() ** 3 + reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, y) + + with pytest.raises(AssertionError): + assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [-1]) + + # For completeness, check that the converse holds when swapping the sign. + reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, -y) + + with pytest.raises(AssertionError): + assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [1]) + + +@pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES) +@pytest.mark.parametrize("monotonic_sign", (-1, 1)) +@pytest.mark.parametrize("depth_first_builder", (True, False)) +@pytest.mark.parametrize("criterion", ("absolute_error", "squared_error")) +def test_nd_tree_nodes_values( + TreeRegressor, monotonic_sign, depth_first_builder, criterion, global_random_seed +): + # Build tree with several features, and make sure the nodes + # values respect the monotonicity constraints. + + # Considering the following tree with a monotonic increase constraint on X[0], + # we should have: + # + # root + # X[0]<=t + # / \ + # a b + # X[0]<=u X[1]<=v + # / \ / \ + # c d e f + # + # i) a <= root <= b + # ii) c <= a <= d <= (a+b)/2 + # iii) (a+b)/2 <= min(e,f) + # For iii) we check that each node value is within the proper lower and + # upper bounds. 
+ + rng = np.random.RandomState(global_random_seed) + n_samples = 1000 + n_features = 2 + monotonic_cst = [monotonic_sign, 0] + X = rng.rand(n_samples, n_features) + y = rng.rand(n_samples) + + if depth_first_builder: + # No max_leaf_nodes, default depth first tree builder + clf = TreeRegressor( + monotonic_cst=monotonic_cst, + criterion=criterion, + random_state=global_random_seed, + ) + else: + # max_leaf_nodes triggers best first tree builder + clf = TreeRegressor( + monotonic_cst=monotonic_cst, + max_leaf_nodes=n_samples, + criterion=criterion, + random_state=global_random_seed, + ) + clf.fit(X, y) + assert_nd_reg_tree_children_monotonic_bounded(clf.tree_, monotonic_cst) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index cadbe2c9f702e..034ee5fc39917 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2359,7 +2359,7 @@ def test_splitter_serializable(Splitter): n_outputs, n_classes = 2, np.array([3, 2], dtype=np.intp) criterion = CRITERIA_CLF["gini"](n_outputs, n_classes) - splitter = Splitter(criterion, max_features, 5, 0.5, rng) + splitter = Splitter(criterion, max_features, 5, 0.5, rng, monotonic_cst=None) splitter_serialize = pickle.dumps(splitter) splitter_back = pickle.loads(splitter_serialize)