Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit cbeb272

Browse files
committed
record boundary info in nodes
1 parent 684727d commit cbeb272

File tree

3 files changed

+31
-12
lines changed

3 files changed

+31
-12
lines changed

sklearn/ensemble/_gb.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ def _update_terminal_regions(
140140
sample_mask,
141141
learning_rate=0.1,
142142
k=0,
143-
line_search=True,
144143
):
145144
"""Update the leaf values to be predicted by the tree and raw_prediction.
146145
@@ -183,14 +182,11 @@ def _update_terminal_regions(
183182
``learning_rate``.
184183
k : int, default=0
185184
The index of the estimator being updated.
186-
line_search : bool, default=True
187-
Whether line search must be performed. Line search must not be
188-
performed under monotonic constraints.
189185
"""
190186
# compute leaf for each sample in ``X``.
191187
terminal_regions = tree.apply(X)
192188

193-
if line_search and not isinstance(loss, HalfSquaredError):
189+
if not isinstance(loss, HalfSquaredError):
194190
# mask all which are not in sample mask.
195191
masked_terminal_regions = terminal_regions.copy()
196192
masked_terminal_regions[~sample_mask] = -1
@@ -262,6 +258,11 @@ def compute_update(y_, indices, neg_gradient, raw_prediction, k):
262258
sw = None if sample_weight is None else sample_weight[indices]
263259
update = compute_update(y_, indices, neg_gradient, raw_prediction, k)
264260

261+
if update > tree.upper_bound[leaf]:
262+
update = tree.upper_bound[leaf]
263+
elif update < tree.lower_bound[leaf]:
264+
update = tree.lower_bound[leaf]
265+
265266
# TODO: Multiply here by learning rate instead of everywhere else.
266267
tree.value[leaf, 0, 0] = update
267268

@@ -515,7 +516,6 @@ def _fit_stage(
515516
sample_mask,
516517
learning_rate=self.learning_rate,
517518
k=k,
518-
line_search=self.monotonic_cst is None,
519519
)
520520

521521
# add tree to ensemble

sklearn/tree/_tree.pxd

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ cdef struct Node:
2222
intp_t n_node_samples # Number of samples at the node
2323
float64_t weighted_n_node_samples # Weighted number of samples at the node
2424
uint8_t missing_go_to_left # Whether features have missing values
25-
25+
float64_t lower_bound # Lower bound on the node's value (used to clamp leaf updates)
26+
float64_t upper_bound # Upper bound on the node's value (used to clamp leaf updates)
2627

2728
cdef struct ParentInfo:
2829
# Structure to store information about the parent of a node
@@ -58,7 +59,9 @@ cdef class Tree:
5859
intp_t feature, float64_t threshold, float64_t impurity,
5960
intp_t n_node_samples,
6061
float64_t weighted_n_node_samples,
61-
uint8_t missing_go_to_left) except -1 nogil
62+
uint8_t missing_go_to_left,
63+
float64_t lower_bound,
64+
float64_t upper_bound) except -1 nogil
6265
cdef int _resize(self, intp_t capacity) except -1 nogil
6366
cdef int _resize_c(self, intp_t capacity=*) except -1 nogil
6467

sklearn/tree/_tree.pyx

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,9 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
268268
node_id = tree._add_node(parent, is_left, is_leaf, split.feature,
269269
split.threshold, parent_record.impurity,
270270
n_node_samples, weighted_n_node_samples,
271-
split.missing_go_to_left)
271+
split.missing_go_to_left,
272+
parent_record.lower_bound,
273+
parent_record.upper_bound)
272274

273275
if node_id == INTPTR_MAX:
274276
rc = -1
@@ -626,7 +628,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
626628
is_left, is_leaf,
627629
split.feature, split.threshold, parent_record.impurity,
628630
n_node_samples, weighted_n_node_samples,
629-
split.missing_go_to_left)
631+
split.missing_go_to_left,
632+
parent_record.lower_bound, parent_record.upper_bound)
630633
if node_id == INTPTR_MAX:
631634
return -1
632635

@@ -774,6 +777,14 @@ cdef class Tree:
774777
def missing_go_to_left(self):
775778
return self._get_node_ndarray()['missing_go_to_left'][:self.node_count]
776779

780+
@property
781+
def lower_bound(self):
782+
return self._get_node_ndarray()['lower_bound'][:self.node_count]
783+
784+
@property
785+
def upper_bound(self):
786+
return self._get_node_ndarray()['upper_bound'][:self.node_count]
787+
777788
@property
778789
def value(self):
779790
return self._get_value_ndarray()[:self.node_count]
@@ -910,7 +921,9 @@ cdef class Tree:
910921
intp_t feature, float64_t threshold, float64_t impurity,
911922
intp_t n_node_samples,
912923
float64_t weighted_n_node_samples,
913-
uint8_t missing_go_to_left) except -1 nogil:
924+
uint8_t missing_go_to_left,
925+
float64_t lower_bound,
926+
float64_t upper_bound) except -1 nogil:
914927
"""Add a node to the tree.
915928
916929
The new node registers itself as the child of its parent.
@@ -927,6 +940,8 @@ cdef class Tree:
927940
node.impurity = impurity
928941
node.n_node_samples = n_node_samples
929942
node.weighted_n_node_samples = weighted_n_node_samples
943+
node.lower_bound = lower_bound
944+
node.upper_bound = upper_bound
930945

931946
if parent != _TREE_UNDEFINED:
932947
if is_left:
@@ -1934,7 +1949,8 @@ cdef _build_pruned_tree(
19341949
new_node_id = tree._add_node(
19351950
parent, is_left, is_leaf, node.feature, node.threshold,
19361951
node.impurity, node.n_node_samples,
1937-
node.weighted_n_node_samples, node.missing_go_to_left)
1952+
node.weighted_n_node_samples, node.missing_go_to_left,
1953+
node.lower_bound, node.upper_bound)
19381954

19391955
if new_node_id == INTPTR_MAX:
19401956
rc = -1

0 commit comments

Comments
 (0)