Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 83dba40

Browse files
committed
Use fast way when there's no constraints
1 parent 57fde2a commit 83dba40

File tree

1 file changed

+30
-25
lines changed
  • sklearn/ensemble/_hist_gradient_boosting

1 file changed

+30
-25
lines changed

sklearn/ensemble/_hist_gradient_boosting/grower.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -202,22 +202,26 @@ def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None,
202202
has_missing_values = np.asarray(has_missing_values, dtype=np.uint8)
203203

204204
if monotonic_cst is None:
205+
self.with_monotonic_cst = False
205206
monotonic_cst = np.full(shape=X_binned.shape[1],
206207
fill_value=MonotonicConstraint.NO_CST,
207208
dtype=np.int8)
208209
else:
210+
self.with_monotonic_cst = True
209211
monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8)
210212

211-
if monotonic_cst.shape[0] != X_binned.shape[1]:
212-
raise ValueError(
213-
"monotonic_cst has shape {} but the input data "
214-
"X has {} features.".format(
215-
monotonic_cst.shape[0], X_binned.shape[1]
213+
if monotonic_cst.shape[0] != X_binned.shape[1]:
214+
raise ValueError(
215+
"monotonic_cst has shape {} but the input data "
216+
"X has {} features.".format(
217+
monotonic_cst.shape[0], X_binned.shape[1]
218+
)
216219
)
217-
)
218-
if np.any(monotonic_cst < -1) or np.any(monotonic_cst > 1):
219-
raise ValueError(
220-
"monotonic_cst must be None or an array-like of -1, 0 or 1.")
220+
if np.any(monotonic_cst < -1) or np.any(monotonic_cst > 1):
221+
raise ValueError(
222+
"monotonic_cst must be None or an array-like of "
223+
"-1, 0 or 1."
224+
)
221225

222226
hessians_are_constant = hessians.shape[0] == 1
223227
self.histogram_builder = HistogramBuilder(
@@ -431,23 +435,24 @@ def split_next(self):
431435
if right_child_node.n_samples < self.min_samples_leaf * 2:
432436
self._finalize_leaf(right_child_node)
433437

434-
# Set value bounds for respecting monotonic constraints
435-
# See test_nodes_values() for details
436-
if (self.monotonic_cst[node.split_info.feature_idx] ==
437-
MonotonicConstraint.NO_CST):
438-
lower_left = lower_right = node.children_lower_bound
439-
upper_left = upper_right = node.children_upper_bound
440-
else:
441-
middle = (left_child_node.value + right_child_node.value) / 2
438+
if self.with_monotonic_cst:
439+
# Set value bounds for respecting monotonic constraints
440+
# See test_nodes_values() for details
442441
if (self.monotonic_cst[node.split_info.feature_idx] ==
443-
MonotonicConstraint.POS):
444-
lower_left, upper_left = node.children_lower_bound, middle
445-
lower_right, upper_right = middle, node.children_upper_bound
446-
else: # NEG
447-
lower_left, upper_left = middle, node.children_upper_bound
448-
lower_right, upper_right = node.children_lower_bound, middle
449-
left_child_node.set_children_bounds(lower_left, upper_left)
450-
right_child_node.set_children_bounds(lower_right, upper_right)
442+
MonotonicConstraint.NO_CST):
443+
lower_left = lower_right = node.children_lower_bound
444+
upper_left = upper_right = node.children_upper_bound
445+
else:
446+
mid = (left_child_node.value + right_child_node.value) / 2
447+
if (self.monotonic_cst[node.split_info.feature_idx] ==
448+
MonotonicConstraint.POS):
449+
lower_left, upper_left = node.children_lower_bound, mid
450+
lower_right, upper_right = mid, node.children_upper_bound
451+
else: # NEG
452+
lower_left, upper_left = mid, node.children_upper_bound
453+
lower_right, upper_right = node.children_lower_bound, mid
454+
left_child_node.set_children_bounds(lower_left, upper_left)
455+
right_child_node.set_children_bounds(lower_right, upper_right)
451456

452457
# Compute histograms of children, and compute their best possible split
453458
# (if needed)

0 commit comments

Comments
 (0)