From 61e28ec5eaa02d9f38e62cc3e9e6732f0ebca183 Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Thu, 22 Dec 2016 00:09:23 -0800 Subject: [PATCH 01/13] Remove inverted dependence of _utils.pyx on _tree.pyx because it was causing all kinds of problems. Now safe_realloc requires the item size to be explicitly provided. Also, it can allocate arrays of pointers to any type by casting to void*. --- sklearn/ensemble/_gradient_boosting.pyx | 8 ++++---- sklearn/tree/_criterion.pyx | 4 ++-- sklearn/tree/_splitter.pyx | 14 +++++++------- sklearn/tree/_tree.pyx | 14 +++++++------- sklearn/tree/_utils.pxd | 8 ++++---- sklearn/tree/_utils.pyx | 22 +++++++++++----------- 6 files changed, 35 insertions(+), 35 deletions(-) diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx index 71371f5c24a48..63a01f63fc851 100644 --- a/sklearn/ensemble/_gradient_boosting.pyx +++ b/sklearn/ensemble/_gradient_boosting.pyx @@ -130,8 +130,8 @@ def _predict_regression_tree_stages_sparse(np.ndarray[object, ndim=2] estimators cdef Tree tree cdef Node** nodes = NULL cdef double** values = NULL - safe_realloc(&nodes, n_stages * n_outputs) - safe_realloc(&values, n_stages * n_outputs) + safe_realloc(&nodes, n_stages * n_outputs, sizeof(void*)) + safe_realloc(&values, n_stages * n_outputs, sizeof(void*)) for stage_i in range(n_stages): for output_i in range(n_outputs): tree = estimators[stage_i, output_i].tree_ @@ -147,8 +147,8 @@ def _predict_regression_tree_stages_sparse(np.ndarray[object, ndim=2] estimators # which features are nonzero in the present sample. cdef SIZE_t* feature_to_sample = NULL - safe_realloc(&X_sample, n_features) - safe_realloc(&feature_to_sample, n_features) + safe_realloc(&X_sample, n_features, sizeof(DTYPE_t)) + safe_realloc(&feature_to_sample, n_features, sizeof(SIZE_t)) memset(feature_to_sample, -1, n_features * sizeof(SIZE_t)) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 5187a5066bb2e..5fb1c459e34c0 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -249,7 +249,7 @@ cdef class ClassificationCriterion(Criterion): self.sum_right = NULL self.n_classes = NULL - safe_realloc(&self.n_classes, n_outputs) + safe_realloc(&self.n_classes, n_outputs, sizeof(SIZE_t)) cdef SIZE_t k = 0 cdef SIZE_t sum_stride = 0 @@ -1040,7 +1040,7 @@ cdef class MAE(RegressionCriterion): self.node_medians = NULL # Allocate memory for the accumulators - safe_realloc(&self.node_medians, n_outputs) + safe_realloc(&self.node_medians, n_outputs, sizeof(DOUBLE_t)) self.left_child = np.empty(n_outputs, dtype='object') self.right_child = np.empty(n_outputs, dtype='object') diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 06dfab587493c..cc40343e00d66 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -147,7 +147,7 @@ cdef class Splitter: # Create a new array which will be used to store nonzero # samples from the feature of interest - cdef SIZE_t* samples = safe_realloc(&self.samples, n_samples) + cdef SIZE_t* samples = safe_realloc(&self.samples, n_samples, sizeof(SIZE_t)) cdef SIZE_t i, j cdef double weighted_n_samples = 0.0 @@ -169,15 +169,15 @@ cdef class Splitter: self.weighted_n_samples = weighted_n_samples cdef SIZE_t n_features = X.shape[1] - cdef SIZE_t* features = safe_realloc(&self.features, n_features) + cdef SIZE_t* features = safe_realloc(&self.features, n_features, sizeof(SIZE_t)) for i in range(n_features): features[i] = i self.n_features = n_features - 
safe_realloc(&self.feature_values, n_samples) - safe_realloc(&self.constant_features, n_features) + safe_realloc(&self.feature_values, n_samples, sizeof(DTYPE_t)) + safe_realloc(&self.constant_features, n_features, sizeof(SIZE_t)) self.y = y.data self.y_stride = y.strides[0] / y.itemsize @@ -295,7 +295,7 @@ cdef class BaseDenseSplitter(Splitter): self.X_idx_sorted.itemsize) self.n_total_samples = X.shape[0] - safe_realloc(&self.sample_mask, self.n_total_samples) + safe_realloc(&self.sample_mask, self.n_total_samples, sizeof(SIZE_t)) memset(self.sample_mask, 0, self.n_total_samples*sizeof(SIZE_t)) return 0 @@ -924,8 +924,8 @@ cdef class BaseSparseSplitter(Splitter): self.n_total_samples = n_total_samples # Initialize auxiliary array used to perform split - safe_realloc(&self.index_to_samples, n_total_samples) - safe_realloc(&self.sorted_samples, n_samples) + safe_realloc(&self.index_to_samples, n_total_samples, sizeof(SIZE_t)) + safe_realloc(&self.sorted_samples, n_samples, sizeof(SIZE_t)) cdef SIZE_t* index_to_samples = self.index_to_samples cdef SIZE_t p diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 69ab8572d2ae5..fd2db0a814c47 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -590,7 +590,7 @@ cdef class Tree: self.n_features = n_features self.n_outputs = n_outputs self.n_classes = NULL - safe_realloc(&self.n_classes, n_outputs) + safe_realloc(&self.n_classes, n_outputs, sizeof(SIZE_t)) self.max_n_classes = np.max(n_classes) self.value_stride = n_outputs * self.max_n_classes @@ -688,8 +688,8 @@ cdef class Tree: else: capacity = 2 * self.capacity - safe_realloc(&self.nodes, capacity) - safe_realloc(&self.value, capacity * self.value_stride) + safe_realloc(&self.nodes, capacity, sizeof(Node)) + safe_realloc(&self.value, capacity * self.value_stride, sizeof(double)) # value memory is initialised to 0 to enable classifier argmax if capacity > self.capacity: @@ -843,8 +843,8 @@ cdef class Tree: # which features are nonzero in the present sample. cdef SIZE_t* feature_to_sample = NULL - safe_realloc(&X_sample, n_features) - safe_realloc(&feature_to_sample, n_features) + safe_realloc(&X_sample, n_features, sizeof(DTYPE_t)) + safe_realloc(&feature_to_sample, n_features, sizeof(SIZE_t)) with nogil: memset(feature_to_sample, -1, n_features * sizeof(SIZE_t)) @@ -989,8 +989,8 @@ cdef class Tree: # which features are nonzero in the present sample. cdef SIZE_t* feature_to_sample = NULL - safe_realloc(&X_sample, n_features) - safe_realloc(&feature_to_sample, n_features) + safe_realloc(&X_sample, n_features, sizeof(DTYPE_t)) + safe_realloc(&feature_to_sample, n_features, sizeof(SIZE_t)) with nogil: memset(feature_to_sample, -1, n_features * sizeof(SIZE_t)) diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 017888ab41db7..c7f78f8ff2253 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -10,7 +10,6 @@ import numpy as np cimport numpy as np -from _tree cimport Node ctypedef np.npy_float32 DTYPE_t # Type of X ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight @@ -18,6 +17,8 @@ ctypedef np.npy_intp SIZE_t # Type for indices and counters ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer +cdef struct Node # Forward declaration + cdef enum: # Max value for our rand_r replacement (near the bottom). 
# We don't use RAND_MAX because it's different across platforms and @@ -37,13 +38,12 @@ ctypedef fused realloc_ptr: (unsigned char*) (WeightedPQueueRecord*) (DOUBLE_t*) - (DOUBLE_t**) (Node*) - (Node**) (StackRecord*) (PriorityHeapRecord*) + (void**) -cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) nogil except * +cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems, size_t elem_bytes) nogil except * cdef np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size) diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index faf2e5b777448..f517a317f1743 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -24,15 +24,15 @@ np.import_array() # Helper functions # ============================================================================= -cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) nogil except *: +cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems, size_t nbytes_elem) nogil except *: # sizeof(realloc_ptr[0]) would be more like idiomatic C, but causes Cython # 0.20.1 to crash. - cdef size_t nbytes = nelems * sizeof(p[0][0]) - if nbytes / sizeof(p[0][0]) != nelems: + cdef size_t nbytes = nelems * nbytes_elem + if nbytes / nbytes_elem != nelems: # Overflow in the multiplication with gil: raise MemoryError("could not allocate (%d * %d) bytes" - % (nelems, sizeof(p[0][0]))) + % (nelems, nbytes_elem)) cdef realloc_ptr tmp = realloc(p[0], nbytes) if tmp == NULL: with gil: @@ -46,7 +46,7 @@ def _realloc_test(): # Helper for tests. Tries to allocate (-1) / 2 * sizeof(size_t) # bytes, which will always overflow. cdef SIZE_t* p = NULL - safe_realloc(&p, (-1) / 2) + safe_realloc(&p, (-1) / 2, sizeof(SIZE_t)) if p != NULL: free(p) assert False @@ -132,7 +132,7 @@ cdef class Stack: if top >= self.capacity: self.capacity *= 2 # Since safe_realloc can raise MemoryError, use `except -1` - safe_realloc(&self.stack_, self.capacity) + safe_realloc(&self.stack_, self.capacity, sizeof(StackRecord)) stack = self.stack_ stack[top].start = start @@ -192,7 +192,7 @@ cdef class PriorityHeap: def __cinit__(self, SIZE_t capacity): self.capacity = capacity self.heap_ptr = 0 - safe_realloc(&self.heap_, capacity) + safe_realloc(&self.heap_, capacity, sizeof(PriorityHeapRecord)) def __dealloc__(self): free(self.heap_) @@ -248,7 +248,7 @@ cdef class PriorityHeap: if heap_ptr >= self.capacity: self.capacity *= 2 # Since safe_realloc can raise MemoryError, use `except -1` - safe_realloc(&self.heap_, self.capacity) + safe_realloc(&self.heap_, self.capacity, sizeof(PriorityHeapRecord)) # Put element as last element of heap heap = self.heap_ @@ -318,7 +318,7 @@ cdef class WeightedPQueue: def __cinit__(self, SIZE_t capacity): self.capacity = capacity self.array_ptr = 0 - safe_realloc(&self.array_, capacity) + safe_realloc(&self.array_, capacity, sizeof(WeightedPQueueRecord)) def __dealloc__(self): free(self.array_) @@ -331,7 +331,7 @@ cdef class WeightedPQueue: """ self.array_ptr = 0 # Since safe_realloc can raise MemoryError, use `except *` - safe_realloc(&self.array_, self.capacity) + safe_realloc(&self.array_, self.capacity, sizeof(WeightedPQueueRecord)) return 0 cdef bint is_empty(self) nogil: @@ -354,7 +354,7 @@ cdef class WeightedPQueue: if array_ptr >= self.capacity: self.capacity *= 2 # Since safe_realloc can raise MemoryError, use `except -1` - safe_realloc(&self.array_, self.capacity) + safe_realloc(&self.array_, self.capacity, sizeof(WeightedPQueueRecord)) # Put element as last element of array array = self.array_ From 
926d48f758f1b1437cf1d8849e13feccc9b4c002 Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Wed, 5 Oct 2016 22:59:15 -0700 Subject: [PATCH 02/13] Tree constructor now checks for mismatched struct sizes. --- sklearn/tree/_tree.pyx | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index fd2db0a814c47..9151ae8cb3437 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -599,6 +599,14 @@ cdef class Tree: for k in range(n_outputs): self.n_classes[k] = n_classes[k] + # Ensure cython and numpy node sizes match up + np_node_size = ( NODE_DTYPE).itemsize + node_size = sizeof(Node) + if (np_node_size != node_size): + raise TypeError('Size of numpy NODE_DTYPE ({} bytes) does not' + ' match size of Node ({} bytes)'.format( + np_node_size, node_size)) + # Inner structures self.max_depth = 0 self.node_count = 0 From 82e932e4dcb31f1e6ffe9890fef9d4ada997d60a Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Fri, 7 Oct 2016 23:03:44 -0700 Subject: [PATCH 03/13] Created SplitValue datatype to generalize the concept of a threshold to categorical variables. Replaced the threshold attribute of SplitRecord and Node with SplitValue. --- sklearn/ensemble/_gradient_boosting.pyx | 6 ++-- sklearn/tree/_splitter.pxd | 5 ++- sklearn/tree/_splitter.pyx | 44 ++++++++++++------------- sklearn/tree/_tree.pxd | 5 ++- sklearn/tree/_tree.pyx | 29 +++++++++------- sklearn/tree/_utils.pxd | 21 ++++++++++++ 6 files changed, 71 insertions(+), 39 deletions(-) diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx index 63a01f63fc851..43bced4b46742 100644 --- a/sklearn/ensemble/_gradient_boosting.pyx +++ b/sklearn/ensemble/_gradient_boosting.pyx @@ -93,7 +93,7 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X, node = root_node # While node not a leaf while node.left_child != TREE_LEAF: - if X[i * n_features + node.feature] <= node.threshold: + if X[i * n_features + node.feature] <= node.split_value.threshold: node = root_node + node.left_child else: node = root_node + node.right_child @@ -174,7 +174,7 @@ def _predict_regression_tree_stages_sparse(np.ndarray[object, ndim=2] estimators else: feature_value = 0. - if feature_value <= node.threshold: + if feature_value <= node.split_value.threshold: node = root_node + node.left_child else: node = root_node + node.right_child @@ -322,7 +322,7 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, if feature_index != -1: # split feature in target set # push left or right child on stack - if X[i, feature_index] <= current_node.threshold: + if X[i, feature_index] <= current_node.split_value.threshold: # left node_stack[stack_size] = (root_node + current_node.left_child) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 4d5c5ae46bceb..8c3d3f47f63c1 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -12,6 +12,8 @@ import numpy as np cimport numpy as np +from ._utils cimport SplitValue + from ._criterion cimport Criterion ctypedef np.npy_float32 DTYPE_t # Type of X @@ -26,7 +28,8 @@ cdef struct SplitRecord: SIZE_t pos # Split samples array at the given position, # i.e. count of samples below threshold for feature. # pos is >= end if the node is a leaf. - double threshold # Threshold to split at. + SplitValue split_value # Generalized threshold for categorical and + # non-categorical features double improvement # Impurity improvement given parent node. 
double impurity_left # Impurity of the left split. double impurity_right # Impurity of the right split. diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index cc40343e00d66..9f119079a7d98 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -48,7 +48,7 @@ cdef inline void _init_split(SplitRecord* self, SIZE_t start_pos) nogil: self.impurity_right = INFINITY self.pos = start_pos self.feature = 0 - self.threshold = 0. + self.split_value.threshold = 0. self.improvement = -INFINITY cdef class Splitter: @@ -481,10 +481,10 @@ cdef class BestSplitter(BaseDenseSplitter): if current_proxy_improvement > best_proxy_improvement: best_proxy_improvement = current_proxy_improvement - current.threshold = (Xf[p - 1] + Xf[p]) / 2.0 + current.split_value.threshold = (Xf[p - 1] + Xf[p]) / 2.0 - if current.threshold == Xf[p]: - current.threshold = Xf[p - 1] + if current.split_value.threshold == Xf[p]: + current.split_value.threshold = Xf[p - 1] best = current # copy @@ -495,7 +495,7 @@ cdef class BestSplitter(BaseDenseSplitter): p = start while p < partition_end: - if X[X_sample_stride * samples[p] + feature_offset] <= best.threshold: + if X[X_sample_stride * samples[p] + feature_offset] <= best.split_value.threshold: p += 1 else: @@ -776,19 +776,18 @@ cdef class RandomSplitter(BaseDenseSplitter): features[f_i], features[f_j] = features[f_j], features[f_i] # Draw a random threshold - current.threshold = rand_uniform(min_feature_value, - max_feature_value, - random_state) + current.split_value.threshold = rand_uniform( + min_feature_value, max_feature_value, random_state) - if current.threshold == max_feature_value: - current.threshold = min_feature_value + if current.split_value.threshold == max_feature_value: + current.split_value.threshold = min_feature_value # Partition partition_end = end p = start while p < partition_end: current_feature_value = Xf[p] - if current_feature_value <= current.threshold: + if current_feature_value <= current.split_value.threshold: p += 1 else: partition_end -= 1 @@ -830,7 +829,7 @@ cdef class RandomSplitter(BaseDenseSplitter): p = start while p < partition_end: - if X[X_sample_stride * samples[p] + feature_stride] <= best.threshold: + if X[X_sample_stride * samples[p] + feature_stride] <= best.split_value.threshold: p += 1 else: @@ -1381,9 +1380,9 @@ cdef class BestSparseSplitter(BaseSparseSplitter): if current_proxy_improvement > best_proxy_improvement: best_proxy_improvement = current_proxy_improvement - current.threshold = (Xf[p_prev] + Xf[p]) / 2.0 - if current.threshold == Xf[p]: - current.threshold = Xf[p_prev] + current.split_value.threshold = (Xf[p_prev] + Xf[p]) / 2.0 + if current.split_value.threshold == Xf[p]: + current.split_value.threshold = Xf[p_prev] best = current @@ -1392,7 +1391,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): self.extract_nnz(best.feature, &end_negative, &start_positive, &is_samples_sorted) - self._partition(best.threshold, end_negative, start_positive, + self._partition(best.split_value.threshold, end_negative, start_positive, best.pos) self.criterion.reset() @@ -1579,15 +1578,14 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): features[f_i], features[f_j] = features[f_j], features[f_i] # Draw a random threshold - current.threshold = rand_uniform(min_feature_value, - max_feature_value, - random_state) + current.split_value.threshold = rand_uniform( + min_feature_value, max_feature_value, random_state) - if current.threshold == max_feature_value: - current.threshold = min_feature_value 
+ if current.split_value.threshold == max_feature_value: + current.split_value.threshold = min_feature_value # Partition - current.pos = self._partition(current.threshold, + current.pos = self._partition(current.split_value.threshold, end_negative, start_positive, start_positive + @@ -1623,7 +1621,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): self.extract_nnz(best.feature, &end_negative, &start_positive, &is_samples_sorted) - self._partition(best.threshold, end_negative, start_positive, + self._partition(best.split_value.threshold, end_negative, start_positive, best.pos) self.criterion.reset() diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 4f9f359725646..e3febde5e4d76 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -19,6 +19,8 @@ ctypedef np.npy_intp SIZE_t # Type for indices and counters ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer +from ._utils cimport SplitValue + from ._splitter cimport Splitter from ._splitter cimport SplitRecord @@ -28,7 +30,8 @@ cdef struct Node: SIZE_t left_child # id of the left child of the node SIZE_t right_child # id of the right child of the node SIZE_t feature # Feature used for splitting the node - DOUBLE_t threshold # Threshold value at the node + SplitValue split_value # Generalized threshold for categorical and + # non-categorical features DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) SIZE_t n_node_samples # Number of samples at the node DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 9151ae8cb3437..812ae5ec805eb 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -65,6 +65,14 @@ cdef SIZE_t _TREE_UNDEFINED = TREE_UNDEFINED cdef SIZE_t INITIAL_STACK_SIZE = 10 # Repeat struct definition for numpy +# NOTE: SPLITVALUE_DTYPE cannot be used with numpy < v1.7 +# Recommend replacing threshold with it when support for v1.6 is dropped + +##SPLITVALUE_DTYPE = np.dtype({ +## 'names': ['threshold', 'cat_split'], +## 'formats': [np.float64, np.uint64], +## 'offsets': [0, 0] +##}) NODE_DTYPE = np.dtype({ 'names': ['left_child', 'right_child', 'feature', 'threshold', 'impurity', 'n_node_samples', 'weighted_n_node_samples'], @@ -74,7 +82,7 @@ NODE_DTYPE = np.dtype({ &( NULL).left_child, &( NULL).right_child, &( NULL).feature, - &( NULL).threshold, + &( NULL).split_value, &( NULL).impurity, &( NULL).n_node_samples, &( NULL).weighted_n_node_samples @@ -182,7 +190,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef SplitRecord split cdef SIZE_t node_id - cdef double threshold cdef double impurity = INFINITY cdef SIZE_t n_constant_features cdef bint is_leaf @@ -232,7 +239,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): is_leaf = is_leaf or (split.pos >= end) node_id = tree._add_node(parent, is_left, is_leaf, split.feature, - split.threshold, impurity, n_node_samples, + split.split_value.threshold, impurity, n_node_samples, weighted_n_node_samples) if node_id == (-1): @@ -363,7 +370,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): node.left_child = _TREE_LEAF node.right_child = _TREE_LEAF node.feature = _TREE_UNDEFINED - node.threshold = _TREE_UNDEFINED + node.split_value.threshold = _TREE_UNDEFINED else: # Node is expandable @@ -452,7 +459,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if parent != NULL else _TREE_UNDEFINED, is_left, is_leaf, - split.feature, split.threshold, impurity, n_node_samples, + 
split.feature, split.split_value.threshold, impurity, n_node_samples, weighted_n_node_samples) if node_id == (-1): return -1 @@ -743,12 +750,12 @@ cdef class Tree: node.left_child = _TREE_LEAF node.right_child = _TREE_LEAF node.feature = _TREE_UNDEFINED - node.threshold = _TREE_UNDEFINED + node.split_value.threshold = _TREE_UNDEFINED else: # left_child and right_child will be set later node.feature = feature - node.threshold = threshold + node.split_value.threshold = threshold self.node_count += 1 @@ -802,7 +809,7 @@ cdef class Tree: while node.left_child != _TREE_LEAF: # ... and node.right_child != _TREE_LEAF: if X_ptr[X_sample_stride * i + - X_fx_stride * node.feature] <= node.threshold: + X_fx_stride * node.feature] <= node.split_value.threshold: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] @@ -873,7 +880,7 @@ cdef class Tree: else: feature_value = 0. - if feature_value <= node.threshold: + if feature_value <= node.split_value.threshold: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] @@ -936,7 +943,7 @@ cdef class Tree: indptr_ptr[i + 1] += 1 if X_ptr[X_sample_stride * i + - X_fx_stride * node.feature] <= node.threshold: + X_fx_stride * node.feature] <= node.split_value.threshold: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] @@ -1024,7 +1031,7 @@ cdef class Tree: else: feature_value = 0. - if feature_value <= node.threshold: + if feature_value <= node.split_value.threshold: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index c7f78f8ff2253..214b004f3487e 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -16,6 +16,27 @@ ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight ctypedef np.npy_intp SIZE_t # Type for indices and counters ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer +ctypedef np.npy_uint64 UINT64_t # Unsigned 64 bit integer + +ctypedef union SplitValue: + # Union type to generalize the concept of a threshold to + # categorical features. For non-categorical features, use the + # threshold member. It acts just as before, where feature values + # less than or equal to the threshold go left, and values greater + # than the threshold go right. + # + # For categorical features, use the cat_split member. It works in + # one of two ways, indicated by the value of its least significant + # bit (LSB). If the LSB is 0, then cat_split acts as a bitfield + # for up to 64 categories, sending samples left if the bit + # corresponding to their category is 1 or right if it is 0. If the + # LSB is 1, then the more significant 32 bits of cat_split is a + # random seed. To evaluate a sample, use the random seed to flip a + # coin (category_value + 1) times and send it left if the last + # flip gives 1; otherwise right. This second method allows up to + # 2**31 category values, but can only be used for RandomSplitter. + DOUBLE_t threshold + UINT64_t cat_split cdef struct Node # Forward declaration From 5cfa6c2ad303c790cffda55270386e9ce710d09d Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Sat, 8 Oct 2016 14:22:46 -0700 Subject: [PATCH 04/13] Added attribute n_categories to Splitter and Tree, an array of ints that defaults to -1 for each feature (indicating non-categorical). 
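For illustration only (not part of this patch), a minimal Python sketch of the convention this attribute encodes: one int32 entry per feature, -1 for ordinary numerical features, otherwise the number of category values, assuming categorical columns are coded as consecutive integers starting at 0. The helper name and its `categorical_columns` argument are hypothetical.

    import numpy as np

    def make_n_categories(X, categorical_columns):
        # -1 marks a non-categorical feature (the default for every column)
        n_categories = np.full(X.shape[1], -1, dtype=np.int32)
        for col in categorical_columns:
            # categories are assumed to be coded 0..k-1, so the count is max+1
            n_categories[col] = int(X[:, col].max()) + 1
        return n_categories

    X = np.array([[0.0, 2.0], [1.5, 0.0], [3.0, 1.0]])
    print(make_n_categories(X, categorical_columns=[1]))  # -> [-1  3]
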
--- sklearn/tree/_splitter.pxd | 3 +++ sklearn/tree/_splitter.pyx | 21 +++++++++++++++++++-- sklearn/tree/_tree.pxd | 3 +++ sklearn/tree/_tree.pyx | 37 +++++++++++++++++++++++++++++++++---- sklearn/tree/_utils.pxd | 2 ++ sklearn/tree/_utils.pyx | 7 +++++++ sklearn/tree/tree.py | 9 +++++++-- 7 files changed, 74 insertions(+), 8 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 8c3d3f47f63c1..5373e31dfa985 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -66,6 +66,8 @@ cdef class Splitter: cdef DOUBLE_t* y cdef SIZE_t y_stride cdef DOUBLE_t* sample_weight + cdef INT32_t *n_categories # (n_features,) array giving number of + # categories (<0 for non-categorical) # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. With this setting, @@ -86,6 +88,7 @@ cdef class Splitter: # Methods cdef int init(self, object X, np.ndarray y, DOUBLE_t* sample_weight, + INT32_t* n_categories, np.ndarray X_idx_sorted=*) except -1 cdef int node_reset(self, SIZE_t start, SIZE_t end, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 9f119079a7d98..644c5a4d95e58 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -95,6 +95,7 @@ cdef class Splitter: self.y = NULL self.y_stride = 0 self.sample_weight = NULL + self.n_categories = NULL self.max_features = max_features self.min_samples_leaf = min_samples_leaf @@ -109,6 +110,7 @@ cdef class Splitter: free(self.features) free(self.constant_features) free(self.feature_values) + free(self.n_categories) def __getstate__(self): return {} @@ -120,6 +122,7 @@ cdef class Splitter: object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, + INT32_t* n_categories, np.ndarray X_idx_sorted=None) except -1: """Initialize the splitter. @@ -140,6 +143,10 @@ cdef class Splitter: The weights of the samples, where higher weighted samples are fit closer than lower weight samples. If not provided, all samples are assumed to have uniform weight. + + n_categories : array of INT32_t, shape=(n_features,) + Number of categories for categorical features, or -1 for + non-categorical features """ self.rand_r_state = self.random_state.randint(0, RAND_R_MAX) @@ -183,6 +190,14 @@ cdef class Splitter: self.y_stride = y.strides[0] / y.itemsize self.sample_weight = sample_weight + + # Initialize the number of categories for each feature + # A value of -1 indicates a non-categorical feature + safe_realloc(&self.n_categories, n_features, sizeof(INT32_t)) + for i in range(n_features): + self.n_categories[i] = (-1 if n_categories == NULL + else n_categories[i]) + return 0 cdef int node_reset(self, SIZE_t start, SIZE_t end, @@ -271,6 +286,7 @@ cdef class BaseDenseSplitter(Splitter): object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, + INT32_t* n_categories, np.ndarray X_idx_sorted=None) except -1: """Initialize the splitter @@ -279,7 +295,7 @@ cdef class BaseDenseSplitter(Splitter): """ # Call parent init - Splitter.init(self, X, y, sample_weight) + Splitter.init(self, X, y, sample_weight, n_categories) # Initialize X cdef np.ndarray X_ndarray = X @@ -896,6 +912,7 @@ cdef class BaseSparseSplitter(Splitter): object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, + INT32_t* n_categories, np.ndarray X_idx_sorted=None) except -1: """Initialize the splitter @@ -903,7 +920,7 @@ cdef class BaseSparseSplitter(Splitter): or 0 otherwise. 
""" # Call parent init - Splitter.init(self, X, y, sample_weight) + Splitter.init(self, X, y, sample_weight, n_categories) if not isinstance(X, csc_matrix): raise ValueError("X should be in csc format") diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index e3febde5e4d76..75f2775ab90e8 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -56,6 +56,8 @@ cdef class Tree: cdef Node* nodes # Array of nodes cdef double* value # (capacity, n_outputs, max_n_classes) array of values cdef SIZE_t value_stride # = n_outputs * max_n_classes + cdef INT32_t *n_categories # (n_features,) array giving number of + # categories (<0 for non-categorical) # Methods cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, @@ -103,5 +105,6 @@ cdef class TreeBuilder: cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=*, + np.ndarray n_categories=*, np.ndarray X_idx_sorted=*) cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 812ae5ec805eb..268ad422bc508 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -36,6 +36,7 @@ from ._utils cimport PriorityHeap from ._utils cimport PriorityHeapRecord from ._utils cimport safe_realloc from ._utils cimport sizet_ptr_to_ndarray +from ._utils cimport int32_ptr_to_ndarray cdef extern from "numpy/arrayobject.h": object PyArray_NewFromDescr(object subtype, np.dtype descr, @@ -98,6 +99,7 @@ cdef class TreeBuilder: cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None, + np.ndarray n_categories=None, np.ndarray X_idx_sorted=None): """Build a decision tree from the training set (X, y).""" pass @@ -148,6 +150,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None, + np.ndarray n_categories=None, np.ndarray X_idx_sorted=None): """Build a decision tree from the training set (X, y).""" @@ -158,6 +161,11 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): if sample_weight is not None: sample_weight_ptr = sample_weight.data + cdef INT32_t *n_categories_ptr = NULL + if n_categories is not None: + n_categories = np.asarray(n_categories, dtype=np.int32, order='C') + n_categories_ptr = n_categories.data + # Initial capacity cdef int init_capacity @@ -177,7 +185,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef double min_impurity_split = self.min_impurity_split # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight_ptr, X_idx_sorted) + splitter.init(X, y, sample_weight_ptr, n_categories_ptr, X_idx_sorted) cdef SIZE_t start cdef SIZE_t end @@ -311,6 +319,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None, + np.ndarray n_categories=None, np.ndarray X_idx_sorted=None): """Build a decision tree from the training set (X, y).""" @@ -321,6 +330,11 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if sample_weight is not None: sample_weight_ptr = sample_weight.data + cdef INT32_t *n_categories_ptr = NULL + if n_categories is not None: + n_categories = np.asarray(n_categories, dtype=np.int32, order='C') + n_categories_ptr = n_categories.data + # Parameters cdef Splitter splitter = self.splitter cdef SIZE_t max_leaf_nodes = self.max_leaf_nodes @@ -329,7 +343,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef SIZE_t min_samples_split = self.min_samples_split # Recursive partition (without actual 
recursion) - splitter.init(X, y, sample_weight_ptr, X_idx_sorted) + splitter.init(X, y, sample_weight_ptr, n_categories_ptr, X_idx_sorted) cdef PriorityHeap frontier = PriorityHeap(INITIAL_STACK_SIZE) cdef PriorityHeapRecord record @@ -539,6 +553,10 @@ cdef class Tree: value : array of double, shape [node_count, n_outputs, max_n_classes] Contains the constant prediction value of each node. + n_categories : array of int32, shape [n_features] + Number of expected category values for categorical features, or + -1 for non-categorical features. + impurity : array of double, shape [node_count] impurity[i] holds the impurity (i.e., the value of the splitting criterion) at node i. @@ -590,14 +608,20 @@ cdef class Tree: def __get__(self): return self._get_value_ndarray()[:self.node_count] + property n_categories: + def __get__(self): + return int32_ptr_to_ndarray(self.n_categories, self.n_features).copy() + def __cinit__(self, int n_features, np.ndarray[SIZE_t, ndim=1] n_classes, - int n_outputs): + int n_outputs, np.ndarray[INT32_t, ndim=1] n_categories): """Constructor.""" # Input/Output layout self.n_features = n_features self.n_outputs = n_outputs self.n_classes = NULL + self.n_categories = NULL safe_realloc(&self.n_classes, n_outputs, sizeof(SIZE_t)) + safe_realloc(&self.n_categories, n_features, sizeof(INT32_t)) self.max_n_classes = np.max(n_classes) self.value_stride = n_outputs * self.max_n_classes @@ -605,6 +629,8 @@ cdef class Tree: cdef SIZE_t k for k in range(n_outputs): self.n_classes[k] = n_classes[k] + for k in range(n_features): + self.n_categories[k] = n_categories[k] # Ensure cython and numpy node sizes match up np_node_size = ( NODE_DTYPE).itemsize @@ -627,12 +653,15 @@ cdef class Tree: free(self.n_classes) free(self.value) free(self.nodes) + free(self.n_categories) def __reduce__(self): """Reduce re-implementation, for pickling.""" return (Tree, (self.n_features, sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), - self.n_outputs), self.__getstate__()) + self.n_outputs, + int32_ptr_to_ndarray(self.n_categories, self.n_features)), + self.__getstate__()) def __getstate__(self): """Getstate re-implementation, for pickling.""" diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 214b004f3487e..07ea94f03e001 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -63,12 +63,14 @@ ctypedef fused realloc_ptr: (StackRecord*) (PriorityHeapRecord*) (void**) + (INT32_t*) cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems, size_t elem_bytes) nogil except * cdef np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size) +cdef np.ndarray int32_ptr_to_ndarray(INT32_t* data, SIZE_t size) cdef SIZE_t rand_int(SIZE_t low, SIZE_t high, UINT32_t* random_state) nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index f517a317f1743..01e93ecc7e189 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -69,6 +69,13 @@ cdef inline np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size): return np.PyArray_SimpleNewFromData(1, shape, np.NPY_INTP, data).copy() +cdef inline np.ndarray int32_ptr_to_ndarray(INT32_t* data, SIZE_t size): + """Encapsulate data into a 1D numpy array of int32's.""" + cdef np.npy_intp shape[1] + shape[0] = size + return np.PyArray_SimpleNewFromData(1, shape, np.NPY_INT32, data) + + cdef inline SIZE_t rand_int(SIZE_t low, SIZE_t high, UINT32_t* random_state) nogil: """Generate a random integer in [0; end).""" diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index f63537f4bfdeb..7a1fdbb5585a7 100644 
--- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -303,6 +303,9 @@ def fit(self, X, y, sample_weight=None, check_input=True, ".shape = {})".format(X.shape, X_idx_sorted.shape)) + # Set n_categories (hard-code -1 for now) + n_categories = np.array([-1] * self.n_features_, dtype=np.int32) + # Build tree criterion = self.criterion if not isinstance(criterion, Criterion): @@ -324,7 +327,8 @@ def fit(self, X, y, sample_weight=None, check_input=True, random_state, self.presort) - self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_) + self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_, + n_categories) # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise if max_leaf_nodes < 0: @@ -340,7 +344,8 @@ def fit(self, X, y, sample_weight=None, check_input=True, max_leaf_nodes, self.min_impurity_split) - builder.build(self.tree_, X, y, sample_weight, X_idx_sorted) + builder.build(self.tree_, X, y, sample_weight, n_categories, + X_idx_sorted) if self.n_outputs_ == 1: self.n_classes_ = self.n_classes_[0] From a65b8482a7f207f480430f943c7025afcf505442 Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Thu, 20 Oct 2016 00:07:37 -0700 Subject: [PATCH 05/13] Added a goes_left function to replace threshold comparisons during prediction with trees. Also introduced category caches for quick evaluation of categorical splits. --- sklearn/tree/_tree.pxd | 11 +++++- sklearn/tree/_tree.pyx | 86 ++++++++++++++++++++++++++++++++++++----- sklearn/tree/_utils.pxd | 10 +++++ sklearn/tree/_utils.pyx | 40 +++++++++++++++++++ 4 files changed, 136 insertions(+), 11 deletions(-) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 75f2775ab90e8..4918532bde5ca 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -37,6 +37,15 @@ cdef struct Node: DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node +cdef class CategoryCacheMgr: + # Class to manage the category cache memory during Tree.apply() + + cdef SIZE_t n_nodes + cdef UINT32_t **bits + + cdef void populate(self, Node *nodes, SIZE_t n_nodes, INT32_t *n_categories) + + cdef class Tree: # The Tree object is a binary tree structure constructed by the # TreeBuilder. 
The tree structure is used for predictions and @@ -61,7 +70,7 @@ cdef class Tree: # Methods cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, - SIZE_t feature, double threshold, double impurity, + SIZE_t feature, SplitValue split_value, double impurity, SIZE_t n_node_samples, double weighted_n_samples) nogil except -1 cdef int _resize(self, SIZE_t capacity) nogil except -1 diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 268ad422bc508..78a7048af9801 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -37,6 +37,8 @@ from ._utils cimport PriorityHeapRecord from ._utils cimport safe_realloc from ._utils cimport sizet_ptr_to_ndarray from ._utils cimport int32_ptr_to_ndarray +from ._utils cimport setup_cat_cache +from ._utils cimport goes_left cdef extern from "numpy/arrayobject.h": object PyArray_NewFromDescr(object subtype, np.dtype descr, @@ -247,7 +249,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): is_leaf = is_leaf or (split.pos >= end) node_id = tree._add_node(parent, is_left, is_leaf, split.feature, - split.split_value.threshold, impurity, n_node_samples, + split.split_value, impurity, n_node_samples, weighted_n_node_samples) if node_id == (-1): @@ -473,7 +475,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if parent != NULL else _TREE_UNDEFINED, is_left, is_leaf, - split.feature, split.split_value.threshold, impurity, n_node_samples, + split.feature, split.split_value, impurity, n_node_samples, weighted_n_node_samples) if node_id == (-1): return -1 @@ -506,6 +508,40 @@ cdef class BestFirstTreeBuilder(TreeBuilder): return 0 +cdef class CategoryCacheMgr: + """Class to manage the category cache memory during Tree.apply() + """ + + def __cinit__(self): + self.n_nodes = 0 + self.bits = NULL + + def _dealloc__(self): + cdef int i + + if self.bits != NULL: + for i in range(self.n_nodes): + free(self.bits[i]) + free(self.bits) + + cdef void populate(self, Node *nodes, SIZE_t n_nodes, INT32_t *n_categories): + cdef SIZE_t i + cdef INT32_t ncat + + if nodes == NULL or n_categories == NULL: + return + + self.n_nodes = n_nodes + safe_realloc( &self.bits, n_nodes, sizeof(void *)) + for i in range(n_nodes): + self.bits[i] = NULL + if nodes[i].left_child != _TREE_LEAF: + ncat = n_categories[nodes[i].feature] + if ncat > 0: + safe_realloc(&self.bits[i], (ncat + 31) // 32, sizeof(UINT32_t)) + setup_cat_cache(self.bits[i], nodes[i].split_value.cat_split, ncat) + + # ============================================================================= # Tree # ============================================================================= @@ -749,7 +785,7 @@ cdef class Tree: return 0 cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, - SIZE_t feature, double threshold, double impurity, + SIZE_t feature, SplitValue split_value, double impurity, SIZE_t n_node_samples, double weighted_n_node_samples) nogil except -1: """Add a node to the tree. 
@@ -784,7 +820,7 @@ cdef class Tree: else: # left_child and right_child will be set later node.feature = feature - node.split_value.threshold = threshold + node.split_value = split_value self.node_count += 1 @@ -830,17 +866,24 @@ cdef class Tree: # Initialize auxiliary data-structure cdef Node* node = NULL cdef SIZE_t i = 0 + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef UINT32_t** cat_caches = cache_mgr.bits + cdef UINT32_t* cache = NULL with nogil: for i in range(n_samples): node = self.nodes + cache = cat_caches[0] # While node not a leaf while node.left_child != _TREE_LEAF: # ... and node.right_child != _TREE_LEAF: - if X_ptr[X_sample_stride * i + - X_fx_stride * node.feature] <= node.split_value.threshold: + if goes_left(X_ptr[X_sample_stride * i + X_fx_stride * node.feature], + node.split_value, self.n_categories[node.feature], cache): + cache = cat_caches[node.left_child] node = &self.nodes[node.left_child] else: + cache = cat_caches[node.right_child] node = &self.nodes[node.right_child] out_ptr[i] = (node - self.nodes) # node offset @@ -881,6 +924,10 @@ cdef class Tree: cdef DTYPE_t* X_sample = NULL cdef SIZE_t i = 0 cdef INT32_t k = 0 + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef UINT32_t** cat_caches = cache_mgr.bits + cdef UINT32_t* cache = NULL # feature_to_sample as a data structure records the last seen sample # for each feature; functionally, it is an efficient way to identify @@ -895,6 +942,7 @@ cdef class Tree: for i in range(n_samples): node = self.nodes + cache = cat_caches[0] for k in range(X_indptr[i], X_indptr[i + 1]): feature_to_sample[X_indices[k]] = i @@ -909,9 +957,12 @@ cdef class Tree: else: feature_value = 0. 
- if feature_value <= node.split_value.threshold: + if goes_left(feature_value, node.split_value, + self.n_categories[node.feature], cache): + cache = cat_caches[node.left_child] node = &self.nodes[node.left_child] else: + cache = cat_caches[node.right_child] node = &self.nodes[node.right_child] out_ptr[i] = (node - self.nodes) # node offset @@ -959,10 +1010,15 @@ cdef class Tree: # Initialize auxiliary data-structure cdef Node* node = NULL cdef SIZE_t i = 0 + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef UINT32_t** cat_caches = cache_mgr.bits + cdef UINT32_t* cache = NULL with nogil: for i in range(n_samples): node = self.nodes + cache = cat_caches[0] indptr_ptr[i + 1] = indptr_ptr[i] # Add all external nodes @@ -971,10 +1027,12 @@ cdef class Tree: indices_ptr[indptr_ptr[i + 1]] = (node - self.nodes) indptr_ptr[i + 1] += 1 - if X_ptr[X_sample_stride * i + - X_fx_stride * node.feature] <= node.split_value.threshold: + if goes_left(X_ptr[X_sample_stride * i + X_fx_stride * node.feature], + node.split_value, self.n_categories[node.feature], cache): + cache = cat_caches[node.left_child] node = &self.nodes[node.left_child] else: + cache = cat_caches[node.right_child] node = &self.nodes[node.right_child] # Add the leave node @@ -1027,6 +1085,10 @@ cdef class Tree: cdef DTYPE_t* X_sample = NULL cdef SIZE_t i = 0 cdef INT32_t k = 0 + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef UINT32_t** cat_caches = cache_mgr.bits + cdef UINT32_t* cache = NULL # feature_to_sample as a data structure records the last seen sample # for each feature; functionally, it is an efficient way to identify @@ -1041,6 +1103,7 @@ cdef class Tree: for i in range(n_samples): node = self.nodes + cache = cat_caches[0] indptr_ptr[i + 1] = indptr_ptr[i] for k in range(X_indptr[i], X_indptr[i + 1]): @@ -1060,9 +1123,12 @@ cdef class Tree: else: feature_value = 0. 
- if feature_value <= node.split_value.threshold: + if goes_left(feature_value, node.split_value, + self.n_categories[node.feature], cache): + cache = cat_caches[node.left_child] node = &self.nodes[node.left_child] else: + cache = cat_caches[node.right_child] node = &self.nodes[node.right_child] # Add the leave node diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 07ea94f03e001..6b91a1e909e1d 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -64,6 +64,7 @@ ctypedef fused realloc_ptr: (PriorityHeapRecord*) (void**) (INT32_t*) + (UINT32_t*) cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems, size_t elem_bytes) nogil except * @@ -82,6 +83,15 @@ cdef double rand_uniform(double low, double high, cdef double log(double x) nogil + +cdef void setup_cat_cache(UINT32_t* cachebits, UINT64_t cat_split, + INT32_t n_categories) nogil + + +cdef bint goes_left(DTYPE_t feature_value, SplitValue split, + INT32_t n_categories, UINT32_t* cachebits) nogil + + # ============================================================================= # Stack data structure # ============================================================================= diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 01e93ecc7e189..712d506d8e0ea 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -93,6 +93,46 @@ cdef inline double log(double x) nogil: return ln(x) / ln(2.0) +cdef inline void setup_cat_cache(UINT32_t *cachebits, UINT64_t cat_split, + INT32_t n_categories) nogil: + """Populate the bits of the category cache from a split. + """ + cdef INT32_t j + cdef UINT32_t rng_seed, val + + if n_categories > 0: + if cat_split & 1: + # RandomSplitter + for j in range((n_categories + 31) // 32): + cachebits[j] = 0 + rng_seed = cat_split >> 32 + for j in range(n_categories): + val = rand_int(0, 2, &rng_seed) + cachebits[j // 32] |= val << (j % 32) + else: + # BestSplitter + for j in range((n_categories + 31) // 32): + cachebits[j] = (cat_split >> (j * 32)) & 0xFFFFFFFF + + +cdef inline bint goes_left(DTYPE_t feature_value, SplitValue split, + INT32_t n_categories, UINT32_t* cachebits) nogil: + """Determine whether a sample goes to the left or right child node.""" + cdef SIZE_t idx, shift + + if n_categories < 1: + # Non-categorical feature + return feature_value <= split.threshold + else: + # Categorical feature, using bit cache + if ( feature_value) < n_categories: + idx = ( feature_value) // 32 + shift = ( feature_value) % 32 + return (cachebits[idx] >> shift) & 1 + else: + return 0 + + # ============================================================================= # Stack data structure # ============================================================================= From 2bd5633779da2e790221e20bf9ea9b1c6c9c19d3 Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Sun, 8 Jan 2017 15:31:57 -0800 Subject: [PATCH 06/13] BestSplitter now calculates the best categorical split. 
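To clarify the subset enumeration added to BestSplitter.node_split below, here is a standalone pure-Python sketch of the same walk over candidate category splits (illustrative only; the function name and its `present` argument are not part of the patch). `present` is the sorted list of category values observed in the node; each yielded value is a candidate cat_split bitfield with bit c set when category c goes to the left child.

    def enumerate_cat_splits(present):
        ncat_present = len(present)
        # Offsets that move bit i (indexed over present categories only) to
        # the absolute bit position of the i-th present category value.
        cat_offs = [c - i for i, c in enumerate(present)]
        for cat_idx in range(1, 1 << (ncat_present - 1)):
            cat_split = 0
            # Doubling cat_idx keeps bit 0 of cat_split clear (an LSB of 1 is
            # reserved for RandomSplitter's random-seed encoding) and visits
            # each left/right partition exactly once, never its complement.
            for i in range(ncat_present):
                cat_split |= ((cat_idx << 1) & (1 << i)) << cat_offs[i]
            yield cat_split

    # Categories {0, 2, 3} present: the 3 distinct non-trivial partitions.
    print([bin(s) for s in enumerate_cat_splits([0, 2, 3])])
    # -> ['0b100', '0b1000', '0b1100']

As a consequence, the lowest present category always stays on the right child, which is how complementary splits are avoided.
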
--- sklearn/tree/_splitter.pxd | 2 + sklearn/tree/_splitter.pyx | 158 +++++++++++++++++++++++++------------ 2 files changed, 109 insertions(+), 51 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 5373e31dfa985..00e1390b797cc 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -21,6 +21,7 @@ ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight ctypedef np.npy_intp SIZE_t # Type for indices and counters ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer +ctypedef np.npy_uint64 UINT64_t # Unsigned 64 bit integer cdef struct SplitRecord: # Data to track sample split @@ -68,6 +69,7 @@ cdef class Splitter: cdef DOUBLE_t* sample_weight cdef INT32_t *n_categories # (n_features,) array giving number of # categories (<0 for non-categorical) + cdef UINT32_t* cat_cache # Cache buffer for fast categorical split evaluation # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. With this setting, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 644c5a4d95e58..63075d3b1d5c5 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -33,6 +33,8 @@ from ._utils cimport rand_int from ._utils cimport rand_uniform from ._utils cimport RAND_R_MAX from ._utils cimport safe_realloc +from ._utils cimport setup_cat_cache +from ._utils cimport goes_left cdef double INFINITY = np.inf @@ -96,6 +98,7 @@ cdef class Splitter: self.y_stride = 0 self.sample_weight = NULL self.n_categories = NULL + self.cat_cache = NULL self.max_features = max_features self.min_samples_leaf = min_samples_leaf @@ -111,6 +114,7 @@ cdef class Splitter: free(self.constant_features) free(self.feature_values) free(self.n_categories) + free(self.cat_cache) def __getstate__(self): return {} @@ -198,6 +202,12 @@ cdef class Splitter: self.n_categories[i] = (-1 if n_categories == NULL else n_categories[i]) + # If needed, allocate cache space for categorical splits + cdef INT32_t max_n_categories = max( + [self.n_categories[i] for i in range(n_features)]) + if max_n_categories > 0: + safe_realloc(&self.cat_cache, (max_n_categories + 31) // 32, sizeof(UINT32_t)) + return 0 cdef int node_reset(self, SIZE_t start, SIZE_t end, @@ -361,7 +371,6 @@ cdef class BestSplitter(BaseDenseSplitter): cdef SIZE_t f_i = n_features cdef SIZE_t f_j - cdef SIZE_t tmp cdef SIZE_t p cdef SIZE_t feature_idx_offset cdef SIZE_t feature_offset @@ -378,6 +387,10 @@ cdef class BestSplitter(BaseDenseSplitter): cdef SIZE_t n_total_constants = n_known_constants cdef DTYPE_t current_feature_value cdef SIZE_t partition_end + cdef bint is_categorical + cdef UINT64_t cat_idx, cat_split + cdef SIZE_t ncat_present + cdef INT32_t cat_offs[64] _init_split(&best, end) @@ -419,9 +432,8 @@ cdef class BestSplitter(BaseDenseSplitter): if f_j < n_known_constants: # f_j in the interval [n_drawn_constants, n_known_constants[ - tmp = features[f_j] - features[f_j] = features[n_drawn_constants] - features[n_drawn_constants] = tmp + features[f_j], features[n_drawn_constants] = ( + features[n_drawn_constants], features[f_j]) n_drawn_constants += 1 @@ -463,63 +475,113 @@ cdef class BestSplitter(BaseDenseSplitter): f_i -= 1 features[f_i], features[f_j] = features[f_j], features[f_i] + # Identify the number of categories present in this node + is_categorical = self.n_categories[current.feature] > 0 + if is_categorical: + cat_split = 0 + ncat_present = 0 + for i in 
range(start, end): + # Xf[i] < 64 already verified in tree.py + cat_split |= ( 1) << ( Xf[i]) + for i in range(self.n_categories[current.feature]): + if (cat_split >> i) & 1: + cat_offs[ncat_present] = i - ncat_present + ncat_present += 1 + # Evaluate all splits self.criterion.reset() p = start + cat_idx = 0 + + while True: + if is_categorical: + cat_idx += 1 + if cat_idx >= ( 1) << (ncat_present - 1): + break + + # Expand the bits of (2 * cat_idx) out into cat_split + # We double cat_idx to avoid double-counting equivalent splits + # This also ensures that cat_split & 1 == 0 as required + cat_split = 0 + for i in range(ncat_present): + cat_split |= ((cat_idx << 1) & (( 1) << i)) << cat_offs[i] + + # Partition + j = start + partition_end = end + while j < partition_end: + if (cat_split >> ( Xf[j])) & 1: + j += 1 + else: + partition_end -= 1 + Xf[j], Xf[partition_end] = Xf[partition_end], Xf[j] + samples[j], samples[partition_end] = ( + samples[partition_end], samples[j]) + current.pos = j + + # Must reset criterion since we've reordered the samples + self.criterion.reset() + else: + # Non-categorical feature + while (p + 1 < end and + Xf[p + 1] <= Xf[p] + FEATURE_THRESHOLD): + p += 1 - while p < end: - while (p + 1 < end and - Xf[p + 1] <= Xf[p] + FEATURE_THRESHOLD): + # (p + 1 >= end) or (X[samples[p + 1], current.feature] > + # X[samples[p], current.feature]) p += 1 + # (p >= end) or (X[samples[p], current.feature] > + # X[samples[p - 1], current.feature]) - # (p + 1 >= end) or (X[samples[p + 1], current.feature] > - # X[samples[p], current.feature]) - p += 1 - # (p >= end) or (X[samples[p], current.feature] > - # X[samples[p - 1], current.feature]) + if p >= end: + break - if p < end: current.pos = p - # Reject if min_samples_leaf is not guaranteed - if (((current.pos - start) < min_samples_leaf) or - ((end - current.pos) < min_samples_leaf)): - continue + # Reject if min_samples_leaf is not guaranteed + if (((current.pos - start) < min_samples_leaf) or + ((end - current.pos) < min_samples_leaf)): + continue - self.criterion.update(current.pos) + self.criterion.update(current.pos) - # Reject if min_weight_leaf is not satisfied - if ((self.criterion.weighted_n_left < min_weight_leaf) or - (self.criterion.weighted_n_right < min_weight_leaf)): - continue + # Reject if min_weight_leaf is not satisfied + if ((self.criterion.weighted_n_left < min_weight_leaf) or + (self.criterion.weighted_n_right < min_weight_leaf)): + continue - current_proxy_improvement = self.criterion.proxy_impurity_improvement() + current_proxy_improvement = self.criterion.proxy_impurity_improvement() - if current_proxy_improvement > best_proxy_improvement: - best_proxy_improvement = current_proxy_improvement + if current_proxy_improvement > best_proxy_improvement: + best_proxy_improvement = current_proxy_improvement + if is_categorical: + current.split_value.cat_split = cat_split + else: current.split_value.threshold = (Xf[p - 1] + Xf[p]) / 2.0 + if current.split_value.threshold == Xf[p]: + current.split_value.threshold = Xf[p - 1] - if current.split_value.threshold == Xf[p]: - current.split_value.threshold = Xf[p - 1] - - best = current # copy + best = current # copy # Reorganize into samples[start:best.pos] + samples[best.pos:end] if best.pos < end: + setup_cat_cache(self.cat_cache, best.split_value.cat_split, + self.n_categories[best.feature]) feature_offset = X_feature_stride * best.feature partition_end = end p = start while p < partition_end: - if X[X_sample_stride * samples[p] + feature_offset] <= 
best.split_value.threshold: + if goes_left(X[X_sample_stride * samples[p] + feature_offset], + best.split_value, self.n_categories[best.feature], + self.cat_cache): p += 1 else: partition_end -= 1 - tmp = samples[partition_end] - samples[partition_end] = samples[p] - samples[p] = tmp + samples[partition_end], samples[p] = ( + samples[p], samples[partition_end]) self.criterion.reset() self.criterion.update(best.pos) @@ -702,7 +764,6 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef SIZE_t f_i = n_features cdef SIZE_t f_j cdef SIZE_t p - cdef SIZE_t tmp cdef SIZE_t feature_stride # Number of features discovered to be constant during the split search cdef SIZE_t n_found_constants = 0 @@ -752,9 +813,8 @@ cdef class RandomSplitter(BaseDenseSplitter): if f_j < n_known_constants: # f_j in the interval [n_drawn_constants, n_known_constants[ - tmp = features[f_j] - features[f_j] = features[n_drawn_constants] - features[n_drawn_constants] = tmp + features[f_j], features[n_drawn_constants] = ( + features[n_drawn_constants], features[f_j]) n_drawn_constants += 1 @@ -811,9 +871,8 @@ cdef class RandomSplitter(BaseDenseSplitter): Xf[p] = Xf[partition_end] Xf[partition_end] = current_feature_value - tmp = samples[partition_end] - samples[partition_end] = samples[p] - samples[p] = tmp + samples[partition_end], samples[p] = ( + samples[p], samples[partition_end]) current.pos = partition_end @@ -851,9 +910,8 @@ cdef class RandomSplitter(BaseDenseSplitter): else: partition_end -= 1 - tmp = samples[partition_end] - samples[partition_end] = samples[p] - samples[p] = tmp + samples[partition_end], samples[p] = ( + samples[p], samples[partition_end]) self.criterion.reset() @@ -1248,7 +1306,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): cdef double best_proxy_improvement = - INFINITY cdef SIZE_t f_i = n_features - cdef SIZE_t f_j, p, tmp + cdef SIZE_t f_j, p cdef SIZE_t n_visited_features = 0 # Number of features discovered to be constant during the split search cdef SIZE_t n_found_constants = 0 @@ -1303,9 +1361,8 @@ cdef class BestSparseSplitter(BaseSparseSplitter): if f_j < n_known_constants: # f_j in the interval [n_drawn_constants, n_known_constants[ - tmp = features[f_j] - features[f_j] = features[n_drawn_constants] - features[n_drawn_constants] = tmp + features[f_j], features[n_drawn_constants] = ( + features[n_drawn_constants], features[f_j]) n_drawn_constants += 1 @@ -1480,7 +1537,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): cdef DTYPE_t current_feature_value cdef SIZE_t f_i = n_features - cdef SIZE_t f_j, p, tmp + cdef SIZE_t f_j, p cdef SIZE_t n_visited_features = 0 # Number of features discovered to be constant during the split search cdef SIZE_t n_found_constants = 0 @@ -1536,9 +1593,8 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): if f_j < n_known_constants: # f_j in the interval [n_drawn_constants, n_known_constants[ - tmp = features[f_j] - features[f_j] = features[n_drawn_constants] - features[n_drawn_constants] = tmp + features[f_j], features[n_drawn_constants] = ( + features[n_drawn_constants], features[f_j]) n_drawn_constants += 1 From e3f0a995604002d7230380b6bb30482ba0d3b23b Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Sat, 29 Oct 2016 00:39:11 -0700 Subject: [PATCH 07/13] Added categorical split code to RandomSplitter.node_split --- sklearn/tree/_splitter.pyx | 65 +++++++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 
63075d3b1d5c5..e66aefbe8694d 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -763,7 +763,7 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef SIZE_t f_i = n_features cdef SIZE_t f_j - cdef SIZE_t p + cdef SIZE_t p, q cdef SIZE_t feature_stride # Number of features discovered to be constant during the split search cdef SIZE_t n_found_constants = 0 @@ -777,6 +777,8 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef DTYPE_t max_feature_value cdef DTYPE_t current_feature_value cdef SIZE_t partition_end + cdef bint is_categorical + cdef UINT64_t split_seed _init_split(&best, end) @@ -851,30 +853,45 @@ cdef class RandomSplitter(BaseDenseSplitter): f_i -= 1 features[f_i], features[f_j] = features[f_j], features[f_i] - # Draw a random threshold - current.split_value.threshold = rand_uniform( - min_feature_value, max_feature_value, random_state) - - if current.split_value.threshold == max_feature_value: - current.split_value.threshold = min_feature_value - - # Partition - partition_end = end - p = start - while p < partition_end: - current_feature_value = Xf[p] - if current_feature_value <= current.split_value.threshold: - p += 1 + # Repeat split & partition if split is trivial, up to 60 times + # (Can only happen with categorical features) + for q in range(60): + # Construct a random split + is_categorical = self.n_categories[current.feature] > 0 + if is_categorical: + split_seed = rand_int(0, RAND_R_MAX + 1, + random_state) + current.split_value.cat_split = (split_seed << 32) | 1 else: - partition_end -= 1 + current.split_value.threshold = rand_uniform( + min_feature_value, max_feature_value, random_state) + if current.split_value.threshold == max_feature_value: + current.split_value.threshold = min_feature_value + + # Partition + setup_cat_cache(self.cat_cache, current.split_value.cat_split, + self.n_categories[current.feature]) + partition_end = end + p = start + while p < partition_end: + current_feature_value = Xf[p] + if goes_left(current_feature_value, current.split_value, + self.n_categories[current.feature], self.cat_cache): + p += 1 + else: + partition_end -= 1 + + Xf[p] = Xf[partition_end] + Xf[partition_end] = current_feature_value - Xf[p] = Xf[partition_end] - Xf[partition_end] = current_feature_value + samples[partition_end], samples[p] = ( + samples[p], samples[partition_end]) - samples[partition_end], samples[p] = ( - samples[p], samples[partition_end]) + current.pos = partition_end - current.pos = partition_end + # Break early if the split is non-trivial + if current.pos != start and current.pos != end: + break # Reject if min_samples_leaf is not guaranteed if (((current.pos - start) < min_samples_leaf) or @@ -899,12 +916,16 @@ cdef class RandomSplitter(BaseDenseSplitter): # Reorganize into samples[start:best.pos] + samples[best.pos:end] feature_stride = X_feature_stride * best.feature if best.pos < end: + setup_cat_cache(self.cat_cache, best.split_value.cat_split, + self.n_categories[best.feature]) if current.feature != best.feature: partition_end = end p = start while p < partition_end: - if X[X_sample_stride * samples[p] + feature_stride] <= best.split_value.threshold: + if goes_left(X[X_sample_stride * samples[p] + feature_stride], + best.split_value, self.n_categories[best.feature], + self.cat_cache): p += 1 else: From 35dad263f0c8d8b6b9a1abb32a799422ac801fb2 Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Sun, 8 Jan 2017 15:46:21 -0800 Subject: [PATCH 08/13] Added an implementation of the Breiman sorting shortcut for finding the best 
categorical split. --- sklearn/tree/_splitter.pxd | 2 + sklearn/tree/_splitter.pyx | 85 +++++++++++++++++++++++++++++++++----- sklearn/tree/tree.py | 9 +++- 3 files changed, 84 insertions(+), 12 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 00e1390b797cc..035a5dd48ee9d 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -63,6 +63,8 @@ cdef class Splitter: cdef bint presort # Whether to use presorting, only # allowed on dense data + cdef bint breiman_shortcut # Whether decision trees are allowed to use the + # Breiman shortcut for categorical features cdef DOUBLE_t* y cdef SIZE_t y_stride diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index e66aefbe8694d..222865b10e3a4 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -62,7 +62,7 @@ cdef class Splitter: def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint breiman_shortcut): """ Parameters ---------- @@ -105,6 +105,7 @@ cdef class Splitter: self.min_weight_leaf = min_weight_leaf self.random_state = random_state self.presort = presort + self.breiman_shortcut = breiman_shortcut def __dealloc__(self): """Destructor.""" @@ -277,7 +278,7 @@ cdef class BaseDenseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint breiman_shortcut): self.X = NULL self.X_sample_stride = 0 @@ -337,6 +338,49 @@ cdef class BestSplitter(BaseDenseSplitter): self.random_state, self.presort), self.__getstate__()) + + cdef void _breiman_sort_categories(self, SIZE_t start, SIZE_t end, INT32_t ncat, + SIZE_t ncat_present, const INT32_t *cat_offs, + SIZE_t *sorted_cat) nogil: + """The Breiman shortcut for finding the best split involves a + preprocessing step wherein we sort the categories by + increasing (weighted) mean of the outcome y (whether 0/1 + binary for classification or quantitative for + regression). This function implements this preprocessing step + and produces a sorted list of category values. 
+ """ + cdef: + SIZE_t *samples = self.samples + DTYPE_t *Xf = self.feature_values + DOUBLE_t *y = self.y + SIZE_t y_stride = self.y_stride + DOUBLE_t *sample_weight = self.sample_weight + DOUBLE_t w + SIZE_t cat, localcat + SIZE_t q, partition_end + DTYPE_t sort_value[64] + DTYPE_t sort_den[64] + + for cat in range(ncat): + sort_value[cat] = 0 + sort_den[cat] = 0 + + for q in range(start, end): + cat = Xf[q] + w = sample_weight[samples[q]] if sample_weight else 1.0 + sort_value[cat] += w * (y[y_stride * samples[q]]) + sort_den[cat] += w + + for localcat in range(ncat_present): + cat = localcat + cat_offs[localcat] + if sort_den[cat] == 0: # Avoid dividing zero by zero + sort_den[cat] = 1 + sort_value[localcat] = sort_value[cat] / sort_den[cat] + sorted_cat[localcat] = cat + + sort(&sort_value[0], sorted_cat, ncat_present) + + cdef int node_split(self, double impurity, SplitRecord* split, SIZE_t* n_constant_features) nogil except -1: """Find the best split on node samples[start:end] @@ -391,6 +435,8 @@ cdef class BestSplitter(BaseDenseSplitter): cdef UINT64_t cat_idx, cat_split cdef SIZE_t ncat_present cdef INT32_t cat_offs[64] + cdef bint breiman_shortcut = self.breiman_shortcut + cdef SIZE_t sorted_cat[64] _init_split(&best, end) @@ -487,6 +533,12 @@ cdef class BestSplitter(BaseDenseSplitter): if (cat_split >> i) & 1: cat_offs[ncat_present] = i - ncat_present ncat_present += 1 + if ncat_present <= 3: + breiman_shortcut = False # No benefit for small N + if breiman_shortcut: + self._breiman_sort_categories( + start, end, self.n_categories[current.feature], + ncat_present, cat_offs, &sorted_cat[0]) # Evaluate all splits self.criterion.reset() @@ -496,15 +548,26 @@ cdef class BestSplitter(BaseDenseSplitter): while True: if is_categorical: cat_idx += 1 - if cat_idx >= ( 1) << (ncat_present - 1): - break + if breiman_shortcut: + if cat_idx >= ncat_present: + break + + cat_split = 0 + for i in range(cat_idx): + cat_split |= ( 1) << sorted_cat[i] + if cat_split & 1: + cat_split = (~cat_split) & ( + (~( 0)) >> (64 - self.n_categories[current.feature])) + else: + if cat_idx >= ( 1) << (ncat_present - 1): + break - # Expand the bits of (2 * cat_idx) out into cat_split - # We double cat_idx to avoid double-counting equivalent splits - # This also ensures that cat_split & 1 == 0 as required - cat_split = 0 - for i in range(ncat_present): - cat_split |= ((cat_idx << 1) & (( 1) << i)) << cat_offs[i] + # Expand the bits of (2 * cat_idx) out into cat_split + # We double cat_idx to avoid double-counting equivalent splits + # This also ensures that cat_split & 1 == 0 as required + cat_split = 0 + for i in range(ncat_present): + cat_split |= ((cat_idx << 1) & (( 1) << i)) << cat_offs[i] # Partition j = start @@ -970,7 +1033,7 @@ cdef class BaseSparseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint breiman_shortcut): # Parent __cinit__ is automatically called self.X_data = NULL diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index 7a1fdbb5585a7..f0ad84cdfd41f 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -315,6 +315,12 @@ def fit(self, X, y, sample_weight=None, check_input=True, else: criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples) + if is_classification: + breiman_shortcut = (self.n_classes_.tolist() == [2] and + (isinstance(criterion, _criterion.Gini) or + isinstance(criterion, _criterion.Entropy))) 
+ else: + breiman_shortcut = isinstance(criterion, _criterion.MSE) SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS @@ -325,7 +331,8 @@ def fit(self, X, y, sample_weight=None, check_input=True, min_samples_leaf, min_weight_leaf, random_state, - self.presort) + self.presort, + breiman_shortcut) self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_, n_categories) From 2312ce03f75f215ea990c66609d02a29f2c21aca Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Fri, 4 Nov 2016 09:16:18 -0700 Subject: [PATCH 09/13] Added categorical constructor parameter and error checking to BaseDecisionTree. --- sklearn/tree/tree.py | 123 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 111 insertions(+), 12 deletions(-) diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index f0ad84cdfd41f..1674d35abff9e 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -21,6 +21,7 @@ from abc import ABCMeta from abc import abstractmethod from math import ceil +import warnings import numpy as np from scipy.sparse import issparse @@ -91,7 +92,8 @@ def __init__(self, random_state, min_impurity_split, class_weight=None, - presort=False): + presort=False, + categorical="none"): self.criterion = criterion self.splitter = splitter self.max_depth = max_depth @@ -104,6 +106,16 @@ def __init__(self, self.min_impurity_split = min_impurity_split self.class_weight = class_weight self.presort = presort + self.categorical = categorical + + # Input validation for parameter categorical + if isinstance(self.categorical, str): + if categorical not in ('all', 'none'): + raise ValueError("Invalid value for categorical: {}. Allowed" + " strings are 'all' or 'none'" + "".format(categorical)) + elif len(np.shape(categorical)) != 1: + raise ValueError("Invalid shape for categorical") def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): @@ -264,6 +276,54 @@ def fit(self, X, y, sample_weight=None, check_input=True, else: sample_weight = expanded_class_weight + # Validate categorical features + if isinstance(self.categorical, str): + if self.categorical == 'none': + categorical = np.array([], dtype=np.int) + elif self.categorical == 'all': + categorical = np.arange(self.n_features_) + else: + # Should have been caught in the constructor, but just in case + raise ValueError("Invalid value for categorical: {}. 
Allowed" + " strings are 'all' or 'none'" + "".format(self.categorical)) + else: + categorical = np.atleast_1d(self.categorical).flatten() + if categorical.dtype == np.bool: + if categorical.size != self.n_features_: + raise ValueError("Shape of boolean parameter categorical must" + " be (n_features,)") + categorical = np.nonzero(categorical)[0] + if (np.size(categorical) > self.n_features_ or + (categorical.size > 0 and + (categorical.min() < 0 or + categorical.max() >= self.n_features_))): + raise ValueError("Invalid shape or invalid feature index for" + " parameter categorical") + if issparse(X): + if categorical.size > 0: + raise NotImplementedError("Categorical features not supported" + " with sparse inputs") + else: + if np.any(X[:, categorical].astype(np.int) < 0): + raise ValueError("Invalid training data: categorical values" + " must be non-negative.") + + # Calculate n_categories and verify they are all at least 1% populated + n_categories = np.array([np.int(X[:, i].max()) + 1 if i in categorical + else -1 for i in range(self.n_features_)], + dtype=np.int32) + n_cat_present = np.array([np.unique(X[:, i].astype(np.int)).size + if i in categorical else -1 + for i in range(self.n_features_)], + dtype=np.int32) + if np.any((n_cat_present < 0.01 * n_cat_present)[categorical]): + warnings.warn("At least one categorical feature has less than 1%" + " of its categories present in the sample. Runtime" + " and memory usage will be much smaller if you" + " represent the categories as sequential integers.", + UserWarning) + # Set min_weight_leaf from min_weight_fraction_leaf if sample_weight is None: min_weight_leaf = (self.min_weight_fraction_leaf * @@ -303,9 +363,6 @@ def fit(self, X, y, sample_weight=None, check_input=True, ".shape = {})".format(X.shape, X_idx_sorted.shape)) - # Set n_categories (hard-code -1 for now) - n_categories = np.array([-1] * self.n_features_, dtype=np.int32) - # Build tree criterion = self.criterion if not isinstance(criterion, Criterion): @@ -334,6 +391,12 @@ def fit(self, X, y, sample_weight=None, check_input=True, self.presort, breiman_shortcut) + if (not isinstance(splitter, _splitter.RandomSplitter) and + np.max(n_categories) > 64): + raise ValueError("Categorical features with greater than 64" + " categories not supported with DecisionTree;" + " try ExtraTree.") + self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_, n_categories) @@ -372,6 +435,9 @@ def _validate_X_predict(self, X, check_input): X.indptr.dtype != np.intc): raise ValueError("No support for np.int64 index based " "sparse matrices") + if issparse(X) and np.any(self.tree_.n_categories > 0): + raise NotImplementedError("Categorical features not supported" + " with sparse inputs") n_features = X.shape[1] if self.n_features_ != n_features: @@ -612,6 +678,19 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin): When using either a smaller dataset or a restricted depth, this may speed up the training. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. 
In this case, + the runtime is linear in the number of categories. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + Attributes ---------- classes_ : array of shape = [n_classes] or a list of such arrays @@ -684,7 +763,8 @@ def __init__(self, max_leaf_nodes=None, min_impurity_split=1e-7, class_weight=None, - presort=False): + presort=False, + categorical="none"): super(DecisionTreeClassifier, self).__init__( criterion=criterion, splitter=splitter, @@ -697,7 +777,8 @@ def __init__(self, class_weight=class_weight, random_state=random_state, min_impurity_split=min_impurity_split, - presort=presort) + presort=presort, + categorical=categorical) def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): @@ -913,6 +994,18 @@ class DecisionTreeRegressor(BaseDecisionTree, RegressorMixin): When using either a smaller dataset or a restricted depth, this may speed up the training. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data using the + ``MSE`` criterion. In this case, the runtime is linear in the number + of categories. Extra-random trees have an upper limit of :math:`2^{31}` + categories, and runtimes linear in the number of categories. + Attributes ---------- feature_importances_ : array of shape = [n_features] @@ -976,7 +1069,8 @@ def __init__(self, random_state=None, max_leaf_nodes=None, min_impurity_split=1e-7, - presort=False): + presort=False, + categorical="none"): super(DecisionTreeRegressor, self).__init__( criterion=criterion, splitter=splitter, @@ -988,7 +1082,8 @@ def __init__(self, max_leaf_nodes=max_leaf_nodes, random_state=random_state, min_impurity_split=min_impurity_split, - presort=presort) + presort=presort, + categorical=categorical) def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): @@ -1069,7 +1164,8 @@ def __init__(self, random_state=None, max_leaf_nodes=None, min_impurity_split=1e-7, - class_weight=None): + class_weight=None, + categorical="none"): super(ExtraTreeClassifier, self).__init__( criterion=criterion, splitter=splitter, @@ -1081,7 +1177,8 @@ def __init__(self, max_leaf_nodes=max_leaf_nodes, class_weight=class_weight, min_impurity_split=min_impurity_split, - random_state=random_state) + random_state=random_state, + categorical=categorical) class ExtraTreeRegressor(DecisionTreeRegressor): @@ -1118,7 +1215,8 @@ def __init__(self, max_features="auto", random_state=None, min_impurity_split=1e-7, - max_leaf_nodes=None): + max_leaf_nodes=None, + categorical="none"): super(ExtraTreeRegressor, self).__init__( criterion=criterion, splitter=splitter, @@ -1129,4 +1227,5 @@ def __init__(self, max_features=max_features, max_leaf_nodes=max_leaf_nodes, min_impurity_split=min_impurity_split, - random_state=random_state) + random_state=random_state, + categorical=categorical) From e0068f8622f3223b85fd00f742bb147ea3d32bee Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Tue, 22 Nov 2016 17:40:49 -0800 Subject: [PATCH 10/13] Added the categorical keyword to forest constructors. 
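
As a quick illustration of the new keyword (not part of the patch; the toy data
and parameter values below are made up, and it assumes the full series is
applied):

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier

    # Toy data: column 0 is ordinal, column 1 holds non-negative category codes 0..3.
    X = np.array([[0.5, 2], [1.2, 0], [0.7, 3], [2.4, 1], [1.9, 2], [0.3, 0]])
    y = np.array([0, 1, 0, 1, 1, 0])

    # Mark column 1 as categorical; splits on it are searched over category
    # subsets rather than thresholds.
    clf = RandomForestClassifier(n_estimators=10, categorical=[1], random_state=0)
    clf.fit(X, y)
    print(clf.predict(X[:2]))

A boolean mask of length n_features, or the strings 'all'/'none', works as
well, mirroring the tree-level parameter.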
--- sklearn/ensemble/forest.py | 63 +++++++++++++++++++++++++++++++++++--- 1 file changed, 58 insertions(+), 5 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 1c160be7870bc..176b623263a6d 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -813,6 +813,17 @@ class RandomForestClassifier(ForestClassifier): .. versionadded:: 0.18 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. + bootstrap : boolean, optional (default=True) Whether bootstrap samples are used when building trees. @@ -908,6 +919,7 @@ def __init__(self, max_features="auto", max_leaf_nodes=None, min_impurity_split=1e-7, + categorical="none", bootstrap=True, oob_score=False, n_jobs=1, @@ -921,7 +933,7 @@ def __init__(self, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -938,6 +950,7 @@ def __init__(self, self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split + self.categorical = categorical class RandomForestRegressor(ForestRegressor): @@ -1025,6 +1038,17 @@ class RandomForestRegressor(ForestRegressor): .. versionadded:: 0.18 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. + bootstrap : boolean, optional (default=True) Whether bootstrap samples are used when building trees. @@ -1089,6 +1113,7 @@ def __init__(self, max_features="auto", max_leaf_nodes=None, min_impurity_split=1e-7, + categorical="none", bootstrap=True, oob_score=False, n_jobs=1, @@ -1101,7 +1126,7 @@ def __init__(self, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1117,6 +1142,7 @@ def __init__(self, self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split + self.categorical = categorical class ExtraTreesClassifier(ForestClassifier): @@ -1197,6 +1223,13 @@ class ExtraTreesClassifier(ForestClassifier): .. 
versionadded:: 0.18 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + bootstrap : boolean, optional (default=False) Whether bootstrap samples are used when building trees. @@ -1293,6 +1326,7 @@ def __init__(self, max_features="auto", max_leaf_nodes=None, min_impurity_split=1e-7, + categorical="none", bootstrap=False, oob_score=False, n_jobs=1, @@ -1306,7 +1340,7 @@ def __init__(self, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1323,6 +1357,7 @@ def __init__(self, self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split + self.categorical = categorical class ExtraTreesRegressor(ForestRegressor): @@ -1408,6 +1443,13 @@ class ExtraTreesRegressor(ForestRegressor): .. versionadded:: 0.18 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + bootstrap : boolean, optional (default=False) Whether bootstrap samples are used when building trees. @@ -1473,6 +1515,7 @@ def __init__(self, max_features="auto", max_leaf_nodes=None, min_impurity_split=1e-7, + categorical="none", bootstrap=False, oob_score=False, n_jobs=1, @@ -1485,7 +1528,7 @@ def __init__(self, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1501,6 +1544,7 @@ def __init__(self, self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split + self.categorical = categorical class RandomTreesEmbedding(BaseForest): @@ -1566,6 +1610,13 @@ class RandomTreesEmbedding(BaseForest): .. versionadded:: 0.18 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + sparse_output : bool, optional (default=True) Whether or not to return a sparse CSR matrix, as default behavior, or to return a dense array compatible with dense pipeline operators. 
@@ -1611,6 +1662,7 @@ def __init__(self, min_weight_fraction_leaf=0., max_leaf_nodes=None, min_impurity_split=1e-7, + categorical="none", sparse_output=True, n_jobs=1, random_state=None, @@ -1622,7 +1674,7 @@ def __init__(self, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=False, oob_score=False, n_jobs=n_jobs, @@ -1639,6 +1691,7 @@ def __init__(self, self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split self.sparse_output = sparse_output + self.categorical = categorical def _set_oob_score(self, X, y): raise NotImplementedError("OOB score not supported by tree embedding") From 1b7732603e29ae594da6b2ae74facc24cfec17e3 Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Sun, 4 Dec 2016 18:33:09 -0800 Subject: [PATCH 11/13] Added some unit tests. --- sklearn/tree/tests/test_tree.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index ff662e9af414a..774391974757f 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1634,3 +1634,34 @@ def _pickle_copy(obj): assert_equal(typename, typename_) assert_equal(n_outputs, n_outputs_) assert_equal(n_samples, n_samples_) + + +def check_invalid_categorical(name): + Tree = ALL_TREES[name] + raise_on_construction = ['invalid string', [[0]]] + raise_on_fit = [[False, False, False], [1, 2], [-3], [0, 0, 1]] + for catval in raise_on_construction: + assert_raises(ValueError, Tree, categorical=catval) + for catval in raise_on_fit: + assert_raises(ValueError, Tree(categorical=catval).fit, X, y) + + +def test_invalid_categorical(): + for name in ALL_TREES: + yield check_invalid_categorical, name + + +def check_no_sparse_with_categorical(name): + X, y, X_sparse = [DATASETS['clf_small'][z] + for z in ['X', 'y', 'X_sparse']] + Tree = ALL_TREES[name] + assert_raises(NotImplementedError, Tree(categorical=[6, 10]).fit, + X_sparse, y) + assert_raises(NotImplementedError, + Tree(categorical=[6, 10]).fit(X, y).predict, X_sparse) + + +def test_no_sparse_with_categorical(): + # Currently we do not support sparse categorical features + for name in SPARSE_TREES: + yield check_no_sparse_with_categorical, name From 149b7bfb2766549ab4002068cb60886be86124a0 Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Tue, 20 Dec 2016 20:59:44 -0800 Subject: [PATCH 12/13] Refactored _partial_dependence_tree a little. 
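
The traversal stack now stores node indices (offsets from root_node) in a
plain intp array instead of raw Node pointers, so the ndarray-to-pointer cast
goes away. For intuition only, a rough pure-Python sketch of the same weighted
traversal (hypothetical left/right/feature/threshold/n_node_samples/value
arrays indexed by node id; this is not the Cython code itself):

    def partial_dependence_one(x, target_features, left, right, feature,
                               threshold, n_node_samples, value,
                               learning_rate=1.0):
        # The stack holds (node index, weight) pairs; index 0 is the root
        # and left[node] == -1 marks a leaf.
        out = 0.0
        stack = [(0, 1.0)]
        while stack:
            node, weight = stack.pop()
            if left[node] == -1:
                # Leaf: accumulate its value, scaled by the path weight.
                out += weight * value[node] * learning_rate
            elif feature[node] in target_features:
                # Target feature: follow the branch chosen by the grid point x.
                child = left[node] if x[feature[node]] <= threshold[node] else right[node]
                stack.append((child, weight))
            else:
                # Complement feature: descend both children, weighting each by
                # its fraction of the parent's training samples.
                frac = n_node_samples[left[node]] / n_node_samples[node]
                stack.append((left[node], weight * frac))
                stack.append((right[node], weight * (1.0 - frac)))
        return out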
--- sklearn/ensemble/_gradient_boosting.pyx | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx index 43bced4b46742..dd0f7b985251f 100644 --- a/sklearn/ensemble/_gradient_boosting.pyx +++ b/sklearn/ensemble/_gradient_boosting.pyx @@ -290,27 +290,25 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, cdef SIZE_t node_count = tree.node_count cdef SIZE_t stack_capacity = node_count * 2 - cdef Node **node_stack cdef double[::1] weight_stack = np_ones((stack_capacity,), dtype=np_float64) cdef SIZE_t stack_size = 1 cdef double left_sample_frac cdef double current_weight cdef double total_weight = 0.0 cdef Node *current_node - underlying_stack = np_zeros((stack_capacity,), dtype=np.intp) - node_stack = ( underlying_stack).data + cdef SIZE_t[::1] node_stack = np_zeros((stack_capacity,), dtype=np.intp) for i in range(X.shape[0]): # init stacks for new example stack_size = 1 - node_stack[0] = root_node + node_stack[0] = 0 weight_stack[0] = 1.0 total_weight = 0.0 while stack_size > 0: # get top node on stack stack_size -= 1 - current_node = node_stack[stack_size] + current_node = root_node + node_stack[stack_size] if current_node.left_child == TREE_LEAF: out[i] += weight_stack[stack_size] * value[current_node - root_node] * \ @@ -324,19 +322,17 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, # push left or right child on stack if X[i, feature_index] <= current_node.split_value.threshold: # left - node_stack[stack_size] = (root_node + - current_node.left_child) + node_stack[stack_size] = current_node.left_child else: # right - node_stack[stack_size] = (root_node + - current_node.right_child) + node_stack[stack_size] = current_node.right_child stack_size += 1 else: # split feature in complement set # push both children onto stack # push left child - node_stack[stack_size] = root_node + current_node.left_child + node_stack[stack_size] = current_node.left_child current_weight = weight_stack[stack_size] left_sample_frac = root_node[current_node.left_child].n_node_samples / \ current_node.n_node_samples @@ -351,7 +347,7 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, stack_size +=1 # push right child - node_stack[stack_size] = root_node + current_node.right_child + node_stack[stack_size] = current_node.right_child weight_stack[stack_size] = current_weight * \ (1.0 - left_sample_frac) stack_size +=1 From b65da1c09e9b3b1de6fd177018407e309ae36678 Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Tue, 20 Dec 2016 22:58:10 -0800 Subject: [PATCH 13/13] Added categorical support to gradient boosting. 
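
The keyword now flows from the boosting estimator into every fitted tree, and
the fast prediction and partial-dependence paths consult per-node category
caches when deciding which branch a sample follows. A minimal usage sketch
(made-up toy data; assumes the whole series is applied):

    import numpy as np
    from sklearn.ensemble import GradientBoostingClassifier

    rng = np.random.RandomState(0)
    # Column 2 holds non-negative integer category codes, as fit() requires.
    X = np.hstack([rng.randn(100, 2), rng.randint(0, 4, size=(100, 1))])
    y = (X[:, 2] == 1).astype(int)

    clf = GradientBoostingClassifier(n_estimators=20, categorical=[2],
                                     random_state=0)
    clf.fit(X, y)
    print(clf.score(X, y))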
--- sklearn/ensemble/_gradient_boosting.pyx | 37 +++++++++++++++++++++++-- sklearn/ensemble/gradient_boosting.py | 36 ++++++++++++++++++++---- 2 files changed, 65 insertions(+), 8 deletions(-) diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx index dd0f7b985251f..d7b9ed7e64445 100644 --- a/sklearn/ensemble/_gradient_boosting.pyx +++ b/sklearn/ensemble/_gradient_boosting.pyx @@ -20,10 +20,13 @@ from scipy.sparse import csr_matrix from sklearn.tree._tree cimport Node from sklearn.tree._tree cimport Tree +from sklearn.tree._tree cimport CategoryCacheMgr from sklearn.tree._tree cimport DTYPE_t from sklearn.tree._tree cimport SIZE_t from sklearn.tree._tree cimport INT32_t +from sklearn.tree._tree cimport UINT32_t from sklearn.tree._utils cimport safe_realloc +from sklearn.tree._utils cimport goes_left ctypedef np.int32_t int32 ctypedef np.float64_t float64 @@ -48,6 +51,8 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X, Py_ssize_t K, Py_ssize_t n_samples, Py_ssize_t n_features, + INT32_t* n_categories, + UINT32_t** cachebits, float64 *out): """Predicts output for regression tree and stores it in ``out[i, k]``. @@ -82,6 +87,12 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X, ``n_samples == X.shape[0]``. n_features : int The number of features; ``n_samples == X.shape[1]``. + n_categories : INT32_t pointer + Array of length n_features containing the number of categories + (for categorical features) or -1 (for non-categorical features) + cachebits : UINT32_t pointer pointer + Array of length node_count containing category cache buffers + for categorical features out : np.float64_t pointer The pointer to the data array where the predictions are stored. ``out`` is assumed to be a two-dimensional array of @@ -89,13 +100,19 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X, """ cdef Py_ssize_t i cdef Node *node + cdef UINT32_t* node_cache + for i in range(n_samples): node = root_node + node_cache = cachebits[0] # While node not a leaf while node.left_child != TREE_LEAF: - if X[i * n_features + node.feature] <= node.split_value.threshold: + if goes_left(X[i * n_features + node.feature], node.split_value, + n_categories[node.feature], node_cache): + node_cache = cachebits[node.left_child] node = root_node + node.left_child else: + node_cache = cachebits[node.right_child] node = root_node + node.right_child out[i * K + k] += scale * value[node - root_node] @@ -213,6 +230,10 @@ def predict_stages(np.ndarray[object, ndim=2] estimators, for k in range(K): tree = estimators[i, k].tree_ + # Make category cache buffers for this tree's nodes + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(tree.nodes, tree.node_count, tree.n_categories) + # avoid buffer validation by casting to ndarray # and get data pointer # need brackets because of casting operator priority @@ -220,6 +241,7 @@ def predict_stages(np.ndarray[object, ndim=2] estimators, ( X).data, tree.nodes, tree.value, scale, k, K, X.shape[0], X.shape[1], + tree.n_categories, cache_mgr.bits, ( out).data) ## out += scale * tree.predict(X).reshape((X.shape[0], 1)) @@ -297,11 +319,19 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, cdef double total_weight = 0.0 cdef Node *current_node cdef SIZE_t[::1] node_stack = np_zeros((stack_capacity,), dtype=np.intp) + cdef UINT32_t** cachebits + cdef UINT32_t* node_cache + + # Make category cache buffers for this tree's nodes + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(root_node, node_count, 
tree.n_categories) + cachebits = cache_mgr.bits for i in range(X.shape[0]): # init stacks for new example stack_size = 1 node_stack[0] = 0 + node_cache = cachebits[0] weight_stack[0] = 1.0 total_weight = 0.0 @@ -309,6 +339,7 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, # get top node on stack stack_size -= 1 current_node = root_node + node_stack[stack_size] + node_cache = cachebits[node_stack[stack_size]] if current_node.left_child == TREE_LEAF: out[i] += weight_stack[stack_size] * value[current_node - root_node] * \ @@ -320,7 +351,9 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, if feature_index != -1: # split feature in target set # push left or right child on stack - if X[i, feature_index] <= current_node.split_value.threshold: + if goes_left(X[i, feature_index], current_node.split_value, + tree.n_categories[current_node.feature], + node_cache): # left node_stack[stack_size] = current_node.left_child else: diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index 49a3fd1a9e348..e8c05f7405ead 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -722,7 +722,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion, min_samples_split, min_samples_leaf, min_weight_fraction_leaf, max_depth, min_impurity_split, init, subsample, max_features, random_state, alpha=0.9, verbose=0, max_leaf_nodes=None, - warm_start=False, presort='auto'): + warm_start=False, presort='auto', categorical='none'): self.n_estimators = n_estimators self.learning_rate = learning_rate @@ -742,6 +742,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion, self.max_leaf_nodes = max_leaf_nodes self.warm_start = warm_start self.presort = presort + self.categorical = categorical def _fit_stage(self, i, X, y, y_pred, sample_weight, sample_mask, random_state, X_idx_sorted, X_csc=None, X_csr=None): @@ -770,7 +771,8 @@ def _fit_stage(self, i, X, y, y_pred, sample_weight, sample_mask, max_features=self.max_features, max_leaf_nodes=self.max_leaf_nodes, random_state=random_state, - presort=self.presort) + presort=self.presort, + categorical=self.categorical) if self.subsample < 1.0: # no inplace multiplication! @@ -1357,6 +1359,17 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): .. versionadded:: 0.17 *presort* parameter. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. 
+ Attributes ---------- feature_importances_ : array, shape = [n_features] @@ -1409,7 +1422,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100, max_depth=3, min_impurity_split=1e-7, init=None, random_state=None, max_features=None, verbose=0, max_leaf_nodes=None, warm_start=False, - presort='auto'): + presort='auto', categorical='none'): super(GradientBoostingClassifier, self).__init__( loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, @@ -1422,7 +1435,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100, max_leaf_nodes=max_leaf_nodes, min_impurity_split=min_impurity_split, warm_start=warm_start, - presort=presort) + presort=presort, categorical=categorical) def _validate_y(self, y): check_classification_targets(y) @@ -1744,6 +1757,17 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin): .. versionadded:: 0.17 optional parameter *presort*. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. + Attributes ---------- feature_importances_ : array, shape = [n_features] @@ -1792,7 +1816,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100, min_samples_leaf=1, min_weight_fraction_leaf=0., max_depth=3, min_impurity_split=1e-7, init=None, random_state=None, max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None, - warm_start=False, presort='auto'): + warm_start=False, presort='auto', categorical='none'): super(GradientBoostingRegressor, self).__init__( loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, @@ -1803,7 +1827,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100, max_features=max_features, min_impurity_split=min_impurity_split, random_state=random_state, alpha=alpha, verbose=verbose, max_leaf_nodes=max_leaf_nodes, warm_start=warm_start, - presort=presort) + presort=presort, categorical=categorical) def predict(self, X): """Predict regression target for X.