From 61e28ec5eaa02d9f38e62cc3e9e6732f0ebca183 Mon Sep 17 00:00:00 2001
From: Jeffrey Blackburne
Date: Thu, 22 Dec 2016 00:09:23 -0800
Subject: [PATCH 01/54] Remove inverted dependence of _utils.pyx on _tree.pyx:
 _utils cimported Node from _tree while _tree cimports from _utils, creating a
 circular dependency. Now safe_realloc requires the item size to be explicitly
 provided. Also, it can allocate arrays of pointers to any type by casting to
 void*.

---
 sklearn/ensemble/_gradient_boosting.pyx |  8 ++++----
 sklearn/tree/_criterion.pyx             |  4 ++--
 sklearn/tree/_splitter.pyx              | 14 +++++++-------
 sklearn/tree/_tree.pyx                  | 14 +++++++-------
 sklearn/tree/_utils.pxd                 |  8 ++++----
 sklearn/tree/_utils.pyx                 | 22 +++++++++++-----------
 6 files changed, 35 insertions(+), 35 deletions(-)

diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx
index 71371f5c24a48..63a01f63fc851 100644
--- a/sklearn/ensemble/_gradient_boosting.pyx
+++ b/sklearn/ensemble/_gradient_boosting.pyx
@@ -130,8 +130,8 @@ def _predict_regression_tree_stages_sparse(np.ndarray[object, ndim=2] estimators
     cdef Tree tree
     cdef Node** nodes = NULL
     cdef double** values = NULL
-    safe_realloc(&nodes, n_stages * n_outputs)
-    safe_realloc(&values, n_stages * n_outputs)
+    safe_realloc(<void***> &nodes, n_stages * n_outputs, sizeof(void*))
+    safe_realloc(<void***> &values, n_stages * n_outputs, sizeof(void*))
     for stage_i in range(n_stages):
         for output_i in range(n_outputs):
             tree = estimators[stage_i, output_i].tree_
@@ -147,8 +147,8 @@ def _predict_regression_tree_stages_sparse(np.ndarray[object, ndim=2] estimators
     # which features are nonzero in the present sample.
     cdef SIZE_t* feature_to_sample = NULL

-    safe_realloc(&X_sample, n_features)
-    safe_realloc(&feature_to_sample, n_features)
+    safe_realloc(&X_sample, n_features, sizeof(DTYPE_t))
+    safe_realloc(&feature_to_sample, n_features, sizeof(SIZE_t))

     memset(feature_to_sample, -1, n_features * sizeof(SIZE_t))

diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx
index 5187a5066bb2e..5fb1c459e34c0 100644
--- a/sklearn/tree/_criterion.pyx
+++ b/sklearn/tree/_criterion.pyx
@@ -249,7 +249,7 @@ cdef class ClassificationCriterion(Criterion):
         self.sum_right = NULL
         self.n_classes = NULL

-        safe_realloc(&self.n_classes, n_outputs)
+        safe_realloc(&self.n_classes, n_outputs, sizeof(SIZE_t))

         cdef SIZE_t k = 0
         cdef SIZE_t sum_stride = 0
@@ -1040,7 +1040,7 @@ cdef class MAE(RegressionCriterion):
         self.node_medians = NULL

         # Allocate memory for the accumulators
-        safe_realloc(&self.node_medians, n_outputs)
+        safe_realloc(&self.node_medians, n_outputs, sizeof(DOUBLE_t))

         self.left_child = np.empty(n_outputs, dtype='object')
         self.right_child = np.empty(n_outputs, dtype='object')

diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx
index 06dfab587493c..cc40343e00d66 100644
--- a/sklearn/tree/_splitter.pyx
+++ b/sklearn/tree/_splitter.pyx
@@ -147,7 +147,7 @@ cdef class Splitter:

         # Create a new array which will be used to store nonzero
         # samples from the feature of interest
-        cdef SIZE_t* samples = safe_realloc(&self.samples, n_samples)
+        cdef SIZE_t* samples = safe_realloc(&self.samples, n_samples, sizeof(SIZE_t))

         cdef SIZE_t i, j
         cdef double weighted_n_samples = 0.0
@@ -169,15 +169,15 @@ cdef class Splitter:
         self.weighted_n_samples = weighted_n_samples

         cdef SIZE_t n_features = X.shape[1]
-        cdef SIZE_t* features = safe_realloc(&self.features, n_features)
+        cdef SIZE_t* features = safe_realloc(&self.features, n_features, sizeof(SIZE_t))

         for i in range(n_features):
             features[i] = i

         self.n_features = n_features

-
        safe_realloc(&self.feature_values, n_samples)
-        safe_realloc(&self.constant_features, n_features)
+        safe_realloc(&self.feature_values, n_samples, sizeof(DTYPE_t))
+        safe_realloc(&self.constant_features, n_features, sizeof(SIZE_t))

         self.y = <DOUBLE_t*> y.data
         self.y_stride = <SIZE_t> y.strides[0] / <SIZE_t> y.itemsize
@@ -295,7 +295,7 @@ cdef class BaseDenseSplitter(Splitter):
                                              self.X_idx_sorted.itemsize)

             self.n_total_samples = X.shape[0]
-            safe_realloc(&self.sample_mask, self.n_total_samples)
+            safe_realloc(&self.sample_mask, self.n_total_samples, sizeof(SIZE_t))
             memset(self.sample_mask, 0, self.n_total_samples*sizeof(SIZE_t))

         return 0
@@ -924,8 +924,8 @@ cdef class BaseSparseSplitter(Splitter):
         self.n_total_samples = n_total_samples

         # Initialize auxiliary array used to perform split
-        safe_realloc(&self.index_to_samples, n_total_samples)
-        safe_realloc(&self.sorted_samples, n_samples)
+        safe_realloc(&self.index_to_samples, n_total_samples, sizeof(SIZE_t))
+        safe_realloc(&self.sorted_samples, n_samples, sizeof(SIZE_t))

         cdef SIZE_t* index_to_samples = self.index_to_samples
         cdef SIZE_t p

diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx
index 69ab8572d2ae5..fd2db0a814c47 100644
--- a/sklearn/tree/_tree.pyx
+++ b/sklearn/tree/_tree.pyx
@@ -590,7 +590,7 @@ cdef class Tree:
         self.n_features = n_features
         self.n_outputs = n_outputs
         self.n_classes = NULL
-        safe_realloc(&self.n_classes, n_outputs)
+        safe_realloc(&self.n_classes, n_outputs, sizeof(SIZE_t))

         self.max_n_classes = np.max(n_classes)
         self.value_stride = n_outputs * self.max_n_classes
@@ -688,8 +688,8 @@ cdef class Tree:
             else:
                 capacity = 2 * self.capacity

-        safe_realloc(&self.nodes, capacity)
-        safe_realloc(&self.value, capacity * self.value_stride)
+        safe_realloc(&self.nodes, capacity, sizeof(Node))
+        safe_realloc(&self.value, capacity * self.value_stride, sizeof(double))

         # value memory is initialised to 0 to enable classifier argmax
         if capacity > self.capacity:
@@ -843,8 +843,8 @@ cdef class Tree:
         # which features are nonzero in the present sample.
         cdef SIZE_t* feature_to_sample = NULL

-        safe_realloc(&X_sample, n_features)
-        safe_realloc(&feature_to_sample, n_features)
+        safe_realloc(&X_sample, n_features, sizeof(DTYPE_t))
+        safe_realloc(&feature_to_sample, n_features, sizeof(SIZE_t))

         with nogil:
             memset(feature_to_sample, -1, n_features * sizeof(SIZE_t))
@@ -989,8 +989,8 @@ cdef class Tree:
         # which features are nonzero in the present sample.
         cdef SIZE_t* feature_to_sample = NULL

-        safe_realloc(&X_sample, n_features)
-        safe_realloc(&feature_to_sample, n_features)
+        safe_realloc(&X_sample, n_features, sizeof(DTYPE_t))
+        safe_realloc(&feature_to_sample, n_features, sizeof(SIZE_t))

         with nogil:
             memset(feature_to_sample, -1, n_features * sizeof(SIZE_t))

diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd
index 017888ab41db7..c7f78f8ff2253 100644
--- a/sklearn/tree/_utils.pxd
+++ b/sklearn/tree/_utils.pxd
@@ -10,7 +10,6 @@
 import numpy as np
 cimport numpy as np

-from _tree cimport Node

 ctypedef np.npy_float32 DTYPE_t          # Type of X
 ctypedef np.npy_float64 DOUBLE_t         # Type of y, sample_weight
@@ -18,6 +17,8 @@ ctypedef np.npy_intp SIZE_t              # Type for indices and counters
 ctypedef np.npy_int32 INT32_t            # Signed 32 bit integer
 ctypedef np.npy_uint32 UINT32_t          # Unsigned 32 bit integer

+cdef struct Node  # Forward declaration
+
 cdef enum:
     # Max value for our rand_r replacement (near the bottom).
    # We don't use RAND_MAX because it's different across platforms and
@@ -37,13 +38,12 @@ ctypedef fused realloc_ptr:
     (unsigned char*)
     (WeightedPQueueRecord*)
     (DOUBLE_t*)
-    (DOUBLE_t**)
     (Node*)
-    (Node**)
     (StackRecord*)
     (PriorityHeapRecord*)
+    (void**)

-cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) nogil except *
+cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems, size_t elem_bytes) nogil except *

 cdef np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size)

diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx
index faf2e5b777448..f517a317f1743 100644
--- a/sklearn/tree/_utils.pyx
+++ b/sklearn/tree/_utils.pyx
@@ -24,15 +24,15 @@ np.import_array()
 # Helper functions
 # =============================================================================

-cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) nogil except *:
+cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems, size_t elem_bytes) nogil except *:
     # sizeof(realloc_ptr[0]) would be more like idiomatic C, but causes Cython
     # 0.20.1 to crash.
-    cdef size_t nbytes = nelems * sizeof(p[0][0])
-    if nbytes / sizeof(p[0][0]) != nelems:
+    cdef size_t nbytes = nelems * elem_bytes
+    if nbytes / elem_bytes != nelems:
         # Overflow in the multiplication
         with gil:
             raise MemoryError("could not allocate (%d * %d) bytes"
-                              % (nelems, sizeof(p[0][0])))
+                              % (nelems, elem_bytes))
     cdef realloc_ptr tmp = <realloc_ptr> realloc(p[0], nbytes)
     if tmp == NULL:
         with gil:
@@ -46,7 +46,7 @@ def _realloc_test():
     # Helper for tests. Tries to allocate <size_t>(-1) / 2 * sizeof(size_t)
     # bytes, which will always overflow.
     cdef SIZE_t* p = NULL
-    safe_realloc(&p, <size_t>(-1) / 2)
+    safe_realloc(&p, <size_t>(-1) / 2, sizeof(SIZE_t))
     if p != NULL:
         free(p)
         assert False
@@ -132,7 +132,7 @@ cdef class Stack:
         if top >= self.capacity:
             self.capacity *= 2
             # Since safe_realloc can raise MemoryError, use `except -1`
-            safe_realloc(&self.stack_, self.capacity)
+            safe_realloc(&self.stack_, self.capacity, sizeof(StackRecord))

         stack = self.stack_
         stack[top].start = start
@@ -192,7 +192,7 @@ cdef class PriorityHeap:
     def __cinit__(self, SIZE_t capacity):
         self.capacity = capacity
         self.heap_ptr = 0
-        safe_realloc(&self.heap_, capacity)
+        safe_realloc(&self.heap_, capacity, sizeof(PriorityHeapRecord))

     def __dealloc__(self):
         free(self.heap_)
@@ -248,7 +248,7 @@ cdef class PriorityHeap:
         if heap_ptr >= self.capacity:
             self.capacity *= 2
             # Since safe_realloc can raise MemoryError, use `except -1`
-            safe_realloc(&self.heap_, self.capacity)
+            safe_realloc(&self.heap_, self.capacity, sizeof(PriorityHeapRecord))

         # Put element as last element of heap
         heap = self.heap_
@@ -318,7 +318,7 @@ cdef class WeightedPQueue:
     def __cinit__(self, SIZE_t capacity):
         self.capacity = capacity
         self.array_ptr = 0
-        safe_realloc(&self.array_, capacity)
+        safe_realloc(&self.array_, capacity, sizeof(WeightedPQueueRecord))

     def __dealloc__(self):
         free(self.array_)
@@ -331,7 +331,7 @@ cdef class WeightedPQueue:
         """
         self.array_ptr = 0
         # Since safe_realloc can raise MemoryError, use `except *`
-        safe_realloc(&self.array_, self.capacity)
+        safe_realloc(&self.array_, self.capacity, sizeof(WeightedPQueueRecord))
         return 0

     cdef bint is_empty(self) nogil:
@@ -354,7 +354,7 @@ cdef class WeightedPQueue:
         if array_ptr >= self.capacity:
             self.capacity *= 2
             # Since safe_realloc can raise MemoryError, use `except -1`
-            safe_realloc(&self.array_, self.capacity)
+            safe_realloc(&self.array_, self.capacity, sizeof(WeightedPQueueRecord))

         # Put element as last element of array
         array = self.array_

From 926d48f758f1b1437cf1d8849e13feccc9b4c002 Mon Sep 17 00:00:00 2001
From: Jeffrey Blackburne
Date: Wed, 5 Oct 2016 22:59:15 -0700
Subject: [PATCH 02/54] Tree constructor now checks for mismatched struct
 sizes.

---
 sklearn/tree/_tree.pyx | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx
index fd2db0a814c47..9151ae8cb3437 100644
--- a/sklearn/tree/_tree.pyx
+++ b/sklearn/tree/_tree.pyx
@@ -599,6 +599,14 @@ cdef class Tree:
         for k in range(n_outputs):
             self.n_classes[k] = n_classes[k]

+        # Ensure cython and numpy node sizes match up
+        np_node_size = (<np.dtype> NODE_DTYPE).itemsize
+        node_size = sizeof(Node)
+        if (np_node_size != node_size):
+            raise TypeError('Size of numpy NODE_DTYPE ({} bytes) does not'
+                            ' match size of Node ({} bytes)'.format(
+                                np_node_size, node_size))
+
         # Inner structures
         self.max_depth = 0
         self.node_count = 0

From 82e932e4dcb31f1e6ffe9890fef9d4ada997d60a Mon Sep 17 00:00:00 2001
From: Jeffrey Blackburne
Date: Fri, 7 Oct 2016 23:03:44 -0700
Subject: [PATCH 03/54] Created SplitValue datatype to generalize the concept
 of a threshold to categorical variables. Replaced the threshold attribute of
 SplitRecord and Node with SplitValue.

---
 sklearn/ensemble/_gradient_boosting.pyx |  6 ++--
 sklearn/tree/_splitter.pxd             |  5 ++-
 sklearn/tree/_splitter.pyx             | 44 ++++++++++++-------------
 sklearn/tree/_tree.pxd                 |  5 ++-
 sklearn/tree/_tree.pyx                 | 29 +++++++++-------
 sklearn/tree/_utils.pxd                | 21 ++++++++++++
 6 files changed, 71 insertions(+), 39 deletions(-)

diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx
index 63a01f63fc851..43bced4b46742 100644
--- a/sklearn/ensemble/_gradient_boosting.pyx
+++ b/sklearn/ensemble/_gradient_boosting.pyx
@@ -93,7 +93,7 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X,
         node = root_node
         # While node not a leaf
         while node.left_child != TREE_LEAF:
-            if X[i * n_features + node.feature] <= node.threshold:
+            if X[i * n_features + node.feature] <= node.split_value.threshold:
                 node = root_node + node.left_child
             else:
                 node = root_node + node.right_child
@@ -174,7 +174,7 @@ def _predict_regression_tree_stages_sparse(np.ndarray[object, ndim=2] estimators
             else:
                 feature_value = 0.

-            if feature_value <= node.threshold:
+            if feature_value <= node.split_value.threshold:
                 node = root_node + node.left_child
             else:
                 node = root_node + node.right_child
@@ -322,7 +322,7 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X,
             if feature_index != -1:
                 # split feature in target set
                 # push left or right child on stack
-                if X[i, feature_index] <= current_node.threshold:
+                if X[i, feature_index] <= current_node.split_value.threshold:
                     # left
                     node_stack[stack_size] = (root_node +
                                               current_node.left_child)

diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd
index 4d5c5ae46bceb..8c3d3f47f63c1 100644
--- a/sklearn/tree/_splitter.pxd
+++ b/sklearn/tree/_splitter.pxd
@@ -12,6 +12,8 @@
 import numpy as np
 cimport numpy as np

+from ._utils cimport SplitValue
+
 from ._criterion cimport Criterion

 ctypedef np.npy_float32 DTYPE_t          # Type of X
@@ -26,7 +28,8 @@ cdef struct SplitRecord:
     SIZE_t pos              # Split samples array at the given position,
                             # i.e. count of samples below threshold for feature.
                             # pos is >= end if the node is a leaf.
-    double threshold        # Threshold to split at.
+    SplitValue split_value  # Generalized threshold for categorical and
+                            # non-categorical features
     double improvement      # Impurity improvement given parent node.
double impurity_left # Impurity of the left split. double impurity_right # Impurity of the right split. diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index cc40343e00d66..9f119079a7d98 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -48,7 +48,7 @@ cdef inline void _init_split(SplitRecord* self, SIZE_t start_pos) nogil: self.impurity_right = INFINITY self.pos = start_pos self.feature = 0 - self.threshold = 0. + self.split_value.threshold = 0. self.improvement = -INFINITY cdef class Splitter: @@ -481,10 +481,10 @@ cdef class BestSplitter(BaseDenseSplitter): if current_proxy_improvement > best_proxy_improvement: best_proxy_improvement = current_proxy_improvement - current.threshold = (Xf[p - 1] + Xf[p]) / 2.0 + current.split_value.threshold = (Xf[p - 1] + Xf[p]) / 2.0 - if current.threshold == Xf[p]: - current.threshold = Xf[p - 1] + if current.split_value.threshold == Xf[p]: + current.split_value.threshold = Xf[p - 1] best = current # copy @@ -495,7 +495,7 @@ cdef class BestSplitter(BaseDenseSplitter): p = start while p < partition_end: - if X[X_sample_stride * samples[p] + feature_offset] <= best.threshold: + if X[X_sample_stride * samples[p] + feature_offset] <= best.split_value.threshold: p += 1 else: @@ -776,19 +776,18 @@ cdef class RandomSplitter(BaseDenseSplitter): features[f_i], features[f_j] = features[f_j], features[f_i] # Draw a random threshold - current.threshold = rand_uniform(min_feature_value, - max_feature_value, - random_state) + current.split_value.threshold = rand_uniform( + min_feature_value, max_feature_value, random_state) - if current.threshold == max_feature_value: - current.threshold = min_feature_value + if current.split_value.threshold == max_feature_value: + current.split_value.threshold = min_feature_value # Partition partition_end = end p = start while p < partition_end: current_feature_value = Xf[p] - if current_feature_value <= current.threshold: + if current_feature_value <= current.split_value.threshold: p += 1 else: partition_end -= 1 @@ -830,7 +829,7 @@ cdef class RandomSplitter(BaseDenseSplitter): p = start while p < partition_end: - if X[X_sample_stride * samples[p] + feature_stride] <= best.threshold: + if X[X_sample_stride * samples[p] + feature_stride] <= best.split_value.threshold: p += 1 else: @@ -1381,9 +1380,9 @@ cdef class BestSparseSplitter(BaseSparseSplitter): if current_proxy_improvement > best_proxy_improvement: best_proxy_improvement = current_proxy_improvement - current.threshold = (Xf[p_prev] + Xf[p]) / 2.0 - if current.threshold == Xf[p]: - current.threshold = Xf[p_prev] + current.split_value.threshold = (Xf[p_prev] + Xf[p]) / 2.0 + if current.split_value.threshold == Xf[p]: + current.split_value.threshold = Xf[p_prev] best = current @@ -1392,7 +1391,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): self.extract_nnz(best.feature, &end_negative, &start_positive, &is_samples_sorted) - self._partition(best.threshold, end_negative, start_positive, + self._partition(best.split_value.threshold, end_negative, start_positive, best.pos) self.criterion.reset() @@ -1579,15 +1578,14 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): features[f_i], features[f_j] = features[f_j], features[f_i] # Draw a random threshold - current.threshold = rand_uniform(min_feature_value, - max_feature_value, - random_state) + current.split_value.threshold = rand_uniform( + min_feature_value, max_feature_value, random_state) - if current.threshold == max_feature_value: - current.threshold = min_feature_value 
+            if current.split_value.threshold == max_feature_value:
+                current.split_value.threshold = min_feature_value

             # Partition
-            current.pos = self._partition(current.threshold,
+            current.pos = self._partition(current.split_value.threshold,
                                           end_negative,
                                           start_positive,
                                           start_positive +
@@ -1623,7 +1621,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter):
             self.extract_nnz(best.feature, &end_negative, &start_positive,
                              &is_samples_sorted)

-            self._partition(best.threshold, end_negative, start_positive,
+            self._partition(best.split_value.threshold, end_negative, start_positive,
                             best.pos)

             self.criterion.reset()

diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd
index 4f9f359725646..e3febde5e4d76 100644
--- a/sklearn/tree/_tree.pxd
+++ b/sklearn/tree/_tree.pxd
@@ -19,6 +19,8 @@ ctypedef np.npy_intp SIZE_t              # Type for indices and counters
 ctypedef np.npy_int32 INT32_t            # Signed 32 bit integer
 ctypedef np.npy_uint32 UINT32_t          # Unsigned 32 bit integer

+from ._utils cimport SplitValue
+
 from ._splitter cimport Splitter
 from ._splitter cimport SplitRecord

@@ -28,7 +30,8 @@ cdef struct Node:
     # Base storage structure for the nodes in a Tree object
     SIZE_t left_child                    # id of the left child of the node
     SIZE_t right_child                   # id of the right child of the node
     SIZE_t feature                       # Feature used for splitting the node
-    DOUBLE_t threshold                   # Threshold value at the node
+    SplitValue split_value               # Generalized threshold for categorical
+                                         # and non-categorical features
     DOUBLE_t impurity                    # Impurity of the node (i.e., the value of the criterion)
     SIZE_t n_node_samples                # Number of samples at the node
     DOUBLE_t weighted_n_node_samples     # Weighted number of samples at the node

diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx
index 9151ae8cb3437..812ae5ec805eb 100644
--- a/sklearn/tree/_tree.pyx
+++ b/sklearn/tree/_tree.pyx
@@ -65,6 +65,14 @@ cdef SIZE_t _TREE_UNDEFINED = TREE_UNDEFINED
 cdef SIZE_t INITIAL_STACK_SIZE = 10

 # Repeat struct definition for numpy
+# NOTE: SPLITVALUE_DTYPE cannot be used with numpy < v1.7
+# Recommend replacing threshold with it when support for v1.6 is dropped
+
+##SPLITVALUE_DTYPE = np.dtype({
+##    'names': ['threshold', 'cat_split'],
+##    'formats': [np.float64, np.uint64],
+##    'offsets': [0, 0]
+##})
 NODE_DTYPE = np.dtype({
     'names': ['left_child', 'right_child', 'feature', 'threshold', 'impurity',
               'n_node_samples', 'weighted_n_node_samples'],
@@ -74,7 +82,7 @@ NODE_DTYPE = np.dtype({
         <Py_ssize_t> &(<Node*> NULL).left_child,
         <Py_ssize_t> &(<Node*> NULL).right_child,
         <Py_ssize_t> &(<Node*> NULL).feature,
-        <Py_ssize_t> &(<Node*> NULL).threshold,
+        <Py_ssize_t> &(<Node*> NULL).split_value,
         <Py_ssize_t> &(<Node*> NULL).impurity,
         <Py_ssize_t> &(<Node*> NULL).n_node_samples,
         <Py_ssize_t> &(<Node*> NULL).weighted_n_node_samples
@@ -182,7 +190,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
         cdef SplitRecord split
         cdef SIZE_t node_id

-        cdef double threshold
         cdef double impurity = INFINITY
         cdef SIZE_t n_constant_features
         cdef bint is_leaf
@@ -232,7 +239,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
                 is_leaf = is_leaf or (split.pos >= end)

                 node_id = tree._add_node(parent, is_left, is_leaf, split.feature,
-                                         split.threshold, impurity, n_node_samples,
+                                         split.split_value.threshold, impurity, n_node_samples,
                                          weighted_n_node_samples)

                 if node_id == <SIZE_t>(-1):
@@ -363,7 +370,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
             node.left_child = _TREE_LEAF
             node.right_child = _TREE_LEAF
             node.feature = _TREE_UNDEFINED
-            node.threshold = _TREE_UNDEFINED
+            node.split_value.threshold = _TREE_UNDEFINED

         else:
             # Node is expandable
@@ -452,7 +459,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
                                  if parent != NULL
                                  else _TREE_UNDEFINED,
                                  is_left, is_leaf,
-                                 split.feature, split.threshold, impurity, n_node_samples,
+
split.feature, split.split_value.threshold, impurity, n_node_samples, weighted_n_node_samples) if node_id == (-1): return -1 @@ -743,12 +750,12 @@ cdef class Tree: node.left_child = _TREE_LEAF node.right_child = _TREE_LEAF node.feature = _TREE_UNDEFINED - node.threshold = _TREE_UNDEFINED + node.split_value.threshold = _TREE_UNDEFINED else: # left_child and right_child will be set later node.feature = feature - node.threshold = threshold + node.split_value.threshold = threshold self.node_count += 1 @@ -802,7 +809,7 @@ cdef class Tree: while node.left_child != _TREE_LEAF: # ... and node.right_child != _TREE_LEAF: if X_ptr[X_sample_stride * i + - X_fx_stride * node.feature] <= node.threshold: + X_fx_stride * node.feature] <= node.split_value.threshold: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] @@ -873,7 +880,7 @@ cdef class Tree: else: feature_value = 0. - if feature_value <= node.threshold: + if feature_value <= node.split_value.threshold: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] @@ -936,7 +943,7 @@ cdef class Tree: indptr_ptr[i + 1] += 1 if X_ptr[X_sample_stride * i + - X_fx_stride * node.feature] <= node.threshold: + X_fx_stride * node.feature] <= node.split_value.threshold: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] @@ -1024,7 +1031,7 @@ cdef class Tree: else: feature_value = 0. - if feature_value <= node.threshold: + if feature_value <= node.split_value.threshold: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index c7f78f8ff2253..214b004f3487e 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -16,6 +16,27 @@ ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight ctypedef np.npy_intp SIZE_t # Type for indices and counters ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer +ctypedef np.npy_uint64 UINT64_t # Unsigned 64 bit integer + +ctypedef union SplitValue: + # Union type to generalize the concept of a threshold to + # categorical features. For non-categorical features, use the + # threshold member. It acts just as before, where feature values + # less than or equal to the threshold go left, and values greater + # than the threshold go right. + # + # For categorical features, use the cat_split member. It works in + # one of two ways, indicated by the value of its least significant + # bit (LSB). If the LSB is 0, then cat_split acts as a bitfield + # for up to 64 categories, sending samples left if the bit + # corresponding to their category is 1 or right if it is 0. If the + # LSB is 1, then the more significant 32 bits of cat_split is a + # random seed. To evaluate a sample, use the random seed to flip a + # coin (category_value + 1) times and send it left if the last + # flip gives 1; otherwise right. This second method allows up to + # 2**31 category values, but can only be used for RandomSplitter. + DOUBLE_t threshold + UINT64_t cat_split cdef struct Node # Forward declaration From 5cfa6c2ad303c790cffda55270386e9ce710d09d Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Sat, 8 Oct 2016 14:22:46 -0700 Subject: [PATCH 04/54] Added attribute n_categories to Splitter and Tree, an array of ints that defaults to -1 for each feature (indicating non-categorical). 
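For illustration only (this example is not part of the patch): a caller
fitting on a mixed dataset would eventually supply one entry per feature,
where -1 marks a numerical feature and a positive value k marks a
categorical feature whose values lie in [0, k):

    import numpy as np

    # Hypothetical example: features 0 and 2 are numerical, feature 1 is
    # categorical with 4 category values (0, 1, 2, 3).
    n_categories = np.array([-1, 4, -1], dtype=np.int32)

The splitters receive this array as an INT32_t pointer and fall back to
ordinary threshold splits wherever the entry is negative. (For now,
tree.py still hard-codes -1 for every feature.)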
--- sklearn/tree/_splitter.pxd | 3 +++ sklearn/tree/_splitter.pyx | 21 +++++++++++++++++++-- sklearn/tree/_tree.pxd | 3 +++ sklearn/tree/_tree.pyx | 37 +++++++++++++++++++++++++++++++++---- sklearn/tree/_utils.pxd | 2 ++ sklearn/tree/_utils.pyx | 7 +++++++ sklearn/tree/tree.py | 9 +++++++-- 7 files changed, 74 insertions(+), 8 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 8c3d3f47f63c1..5373e31dfa985 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -66,6 +66,8 @@ cdef class Splitter: cdef DOUBLE_t* y cdef SIZE_t y_stride cdef DOUBLE_t* sample_weight + cdef INT32_t *n_categories # (n_features,) array giving number of + # categories (<0 for non-categorical) # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. With this setting, @@ -86,6 +88,7 @@ cdef class Splitter: # Methods cdef int init(self, object X, np.ndarray y, DOUBLE_t* sample_weight, + INT32_t* n_categories, np.ndarray X_idx_sorted=*) except -1 cdef int node_reset(self, SIZE_t start, SIZE_t end, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 9f119079a7d98..644c5a4d95e58 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -95,6 +95,7 @@ cdef class Splitter: self.y = NULL self.y_stride = 0 self.sample_weight = NULL + self.n_categories = NULL self.max_features = max_features self.min_samples_leaf = min_samples_leaf @@ -109,6 +110,7 @@ cdef class Splitter: free(self.features) free(self.constant_features) free(self.feature_values) + free(self.n_categories) def __getstate__(self): return {} @@ -120,6 +122,7 @@ cdef class Splitter: object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, + INT32_t* n_categories, np.ndarray X_idx_sorted=None) except -1: """Initialize the splitter. @@ -140,6 +143,10 @@ cdef class Splitter: The weights of the samples, where higher weighted samples are fit closer than lower weight samples. If not provided, all samples are assumed to have uniform weight. + + n_categories : array of INT32_t, shape=(n_features,) + Number of categories for categorical features, or -1 for + non-categorical features """ self.rand_r_state = self.random_state.randint(0, RAND_R_MAX) @@ -183,6 +190,14 @@ cdef class Splitter: self.y_stride = y.strides[0] / y.itemsize self.sample_weight = sample_weight + + # Initialize the number of categories for each feature + # A value of -1 indicates a non-categorical feature + safe_realloc(&self.n_categories, n_features, sizeof(INT32_t)) + for i in range(n_features): + self.n_categories[i] = (-1 if n_categories == NULL + else n_categories[i]) + return 0 cdef int node_reset(self, SIZE_t start, SIZE_t end, @@ -271,6 +286,7 @@ cdef class BaseDenseSplitter(Splitter): object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, + INT32_t* n_categories, np.ndarray X_idx_sorted=None) except -1: """Initialize the splitter @@ -279,7 +295,7 @@ cdef class BaseDenseSplitter(Splitter): """ # Call parent init - Splitter.init(self, X, y, sample_weight) + Splitter.init(self, X, y, sample_weight, n_categories) # Initialize X cdef np.ndarray X_ndarray = X @@ -896,6 +912,7 @@ cdef class BaseSparseSplitter(Splitter): object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, + INT32_t* n_categories, np.ndarray X_idx_sorted=None) except -1: """Initialize the splitter @@ -903,7 +920,7 @@ cdef class BaseSparseSplitter(Splitter): or 0 otherwise. 
""" # Call parent init - Splitter.init(self, X, y, sample_weight) + Splitter.init(self, X, y, sample_weight, n_categories) if not isinstance(X, csc_matrix): raise ValueError("X should be in csc format") diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index e3febde5e4d76..75f2775ab90e8 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -56,6 +56,8 @@ cdef class Tree: cdef Node* nodes # Array of nodes cdef double* value # (capacity, n_outputs, max_n_classes) array of values cdef SIZE_t value_stride # = n_outputs * max_n_classes + cdef INT32_t *n_categories # (n_features,) array giving number of + # categories (<0 for non-categorical) # Methods cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, @@ -103,5 +105,6 @@ cdef class TreeBuilder: cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=*, + np.ndarray n_categories=*, np.ndarray X_idx_sorted=*) cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 812ae5ec805eb..268ad422bc508 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -36,6 +36,7 @@ from ._utils cimport PriorityHeap from ._utils cimport PriorityHeapRecord from ._utils cimport safe_realloc from ._utils cimport sizet_ptr_to_ndarray +from ._utils cimport int32_ptr_to_ndarray cdef extern from "numpy/arrayobject.h": object PyArray_NewFromDescr(object subtype, np.dtype descr, @@ -98,6 +99,7 @@ cdef class TreeBuilder: cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None, + np.ndarray n_categories=None, np.ndarray X_idx_sorted=None): """Build a decision tree from the training set (X, y).""" pass @@ -148,6 +150,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None, + np.ndarray n_categories=None, np.ndarray X_idx_sorted=None): """Build a decision tree from the training set (X, y).""" @@ -158,6 +161,11 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): if sample_weight is not None: sample_weight_ptr = sample_weight.data + cdef INT32_t *n_categories_ptr = NULL + if n_categories is not None: + n_categories = np.asarray(n_categories, dtype=np.int32, order='C') + n_categories_ptr = n_categories.data + # Initial capacity cdef int init_capacity @@ -177,7 +185,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef double min_impurity_split = self.min_impurity_split # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight_ptr, X_idx_sorted) + splitter.init(X, y, sample_weight_ptr, n_categories_ptr, X_idx_sorted) cdef SIZE_t start cdef SIZE_t end @@ -311,6 +319,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None, + np.ndarray n_categories=None, np.ndarray X_idx_sorted=None): """Build a decision tree from the training set (X, y).""" @@ -321,6 +330,11 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if sample_weight is not None: sample_weight_ptr = sample_weight.data + cdef INT32_t *n_categories_ptr = NULL + if n_categories is not None: + n_categories = np.asarray(n_categories, dtype=np.int32, order='C') + n_categories_ptr = n_categories.data + # Parameters cdef Splitter splitter = self.splitter cdef SIZE_t max_leaf_nodes = self.max_leaf_nodes @@ -329,7 +343,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef SIZE_t min_samples_split = self.min_samples_split # Recursive partition (without actual 
recursion) - splitter.init(X, y, sample_weight_ptr, X_idx_sorted) + splitter.init(X, y, sample_weight_ptr, n_categories_ptr, X_idx_sorted) cdef PriorityHeap frontier = PriorityHeap(INITIAL_STACK_SIZE) cdef PriorityHeapRecord record @@ -539,6 +553,10 @@ cdef class Tree: value : array of double, shape [node_count, n_outputs, max_n_classes] Contains the constant prediction value of each node. + n_categories : array of int32, shape [n_features] + Number of expected category values for categorical features, or + -1 for non-categorical features. + impurity : array of double, shape [node_count] impurity[i] holds the impurity (i.e., the value of the splitting criterion) at node i. @@ -590,14 +608,20 @@ cdef class Tree: def __get__(self): return self._get_value_ndarray()[:self.node_count] + property n_categories: + def __get__(self): + return int32_ptr_to_ndarray(self.n_categories, self.n_features).copy() + def __cinit__(self, int n_features, np.ndarray[SIZE_t, ndim=1] n_classes, - int n_outputs): + int n_outputs, np.ndarray[INT32_t, ndim=1] n_categories): """Constructor.""" # Input/Output layout self.n_features = n_features self.n_outputs = n_outputs self.n_classes = NULL + self.n_categories = NULL safe_realloc(&self.n_classes, n_outputs, sizeof(SIZE_t)) + safe_realloc(&self.n_categories, n_features, sizeof(INT32_t)) self.max_n_classes = np.max(n_classes) self.value_stride = n_outputs * self.max_n_classes @@ -605,6 +629,8 @@ cdef class Tree: cdef SIZE_t k for k in range(n_outputs): self.n_classes[k] = n_classes[k] + for k in range(n_features): + self.n_categories[k] = n_categories[k] # Ensure cython and numpy node sizes match up np_node_size = ( NODE_DTYPE).itemsize @@ -627,12 +653,15 @@ cdef class Tree: free(self.n_classes) free(self.value) free(self.nodes) + free(self.n_categories) def __reduce__(self): """Reduce re-implementation, for pickling.""" return (Tree, (self.n_features, sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), - self.n_outputs), self.__getstate__()) + self.n_outputs, + int32_ptr_to_ndarray(self.n_categories, self.n_features)), + self.__getstate__()) def __getstate__(self): """Getstate re-implementation, for pickling.""" diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 214b004f3487e..07ea94f03e001 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -63,12 +63,14 @@ ctypedef fused realloc_ptr: (StackRecord*) (PriorityHeapRecord*) (void**) + (INT32_t*) cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems, size_t elem_bytes) nogil except * cdef np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size) +cdef np.ndarray int32_ptr_to_ndarray(INT32_t* data, SIZE_t size) cdef SIZE_t rand_int(SIZE_t low, SIZE_t high, UINT32_t* random_state) nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index f517a317f1743..01e93ecc7e189 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -69,6 +69,13 @@ cdef inline np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size): return np.PyArray_SimpleNewFromData(1, shape, np.NPY_INTP, data).copy() +cdef inline np.ndarray int32_ptr_to_ndarray(INT32_t* data, SIZE_t size): + """Encapsulate data into a 1D numpy array of int32's.""" + cdef np.npy_intp shape[1] + shape[0] = size + return np.PyArray_SimpleNewFromData(1, shape, np.NPY_INT32, data) + + cdef inline SIZE_t rand_int(SIZE_t low, SIZE_t high, UINT32_t* random_state) nogil: """Generate a random integer in [0; end).""" diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index f63537f4bfdeb..7a1fdbb5585a7 100644 
--- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -303,6 +303,9 @@ def fit(self, X, y, sample_weight=None, check_input=True, ".shape = {})".format(X.shape, X_idx_sorted.shape)) + # Set n_categories (hard-code -1 for now) + n_categories = np.array([-1] * self.n_features_, dtype=np.int32) + # Build tree criterion = self.criterion if not isinstance(criterion, Criterion): @@ -324,7 +327,8 @@ def fit(self, X, y, sample_weight=None, check_input=True, random_state, self.presort) - self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_) + self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_, + n_categories) # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise if max_leaf_nodes < 0: @@ -340,7 +344,8 @@ def fit(self, X, y, sample_weight=None, check_input=True, max_leaf_nodes, self.min_impurity_split) - builder.build(self.tree_, X, y, sample_weight, X_idx_sorted) + builder.build(self.tree_, X, y, sample_weight, n_categories, + X_idx_sorted) if self.n_outputs_ == 1: self.n_classes_ = self.n_classes_[0] From a65b8482a7f207f480430f943c7025afcf505442 Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Thu, 20 Oct 2016 00:07:37 -0700 Subject: [PATCH 05/54] Added a goes_left function to replace threshold comparisons during prediction with trees. Also introduced category caches for quick evaluation of categorical splits. --- sklearn/tree/_tree.pxd | 11 +++++- sklearn/tree/_tree.pyx | 86 ++++++++++++++++++++++++++++++++++++----- sklearn/tree/_utils.pxd | 10 +++++ sklearn/tree/_utils.pyx | 40 +++++++++++++++++++ 4 files changed, 136 insertions(+), 11 deletions(-) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 75f2775ab90e8..4918532bde5ca 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -37,6 +37,15 @@ cdef struct Node: DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node +cdef class CategoryCacheMgr: + # Class to manage the category cache memory during Tree.apply() + + cdef SIZE_t n_nodes + cdef UINT32_t **bits + + cdef void populate(self, Node *nodes, SIZE_t n_nodes, INT32_t *n_categories) + + cdef class Tree: # The Tree object is a binary tree structure constructed by the # TreeBuilder. 
The tree structure is used for predictions and
@@ -61,7 +70,7 @@ cdef class Tree:

     # Methods
     cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf,
-                          SIZE_t feature, double threshold, double impurity,
+                          SIZE_t feature, SplitValue split_value, double impurity,
                           SIZE_t n_node_samples,
                           double weighted_n_samples) nogil except -1
     cdef int _resize(self, SIZE_t capacity) nogil except -1

diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx
index 268ad422bc508..78a7048af9801 100644
--- a/sklearn/tree/_tree.pyx
+++ b/sklearn/tree/_tree.pyx
@@ -37,6 +37,8 @@ from ._utils cimport PriorityHeap
 from ._utils cimport PriorityHeapRecord
 from ._utils cimport safe_realloc
 from ._utils cimport sizet_ptr_to_ndarray
 from ._utils cimport int32_ptr_to_ndarray
+from ._utils cimport setup_cat_cache
+from ._utils cimport goes_left

 cdef extern from "numpy/arrayobject.h":
     object PyArray_NewFromDescr(object subtype, np.dtype descr,
@@ -247,7 +249,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
                 is_leaf = is_leaf or (split.pos >= end)

                 node_id = tree._add_node(parent, is_left, is_leaf, split.feature,
-                                         split.split_value.threshold, impurity, n_node_samples,
+                                         split.split_value, impurity, n_node_samples,
                                          weighted_n_node_samples)

                 if node_id == <SIZE_t>(-1):
@@ -473,7 +475,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
                                  if parent != NULL
                                  else _TREE_UNDEFINED,
                                  is_left, is_leaf,
-                                 split.feature, split.split_value.threshold, impurity, n_node_samples,
+                                 split.feature, split.split_value, impurity, n_node_samples,
                                  weighted_n_node_samples)
         if node_id == <SIZE_t>(-1):
             return -1
@@ -506,6 +508,40 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
         return 0


+cdef class CategoryCacheMgr:
+    """Class to manage the category cache memory during Tree.apply()
+    """
+
+    def __cinit__(self):
+        self.n_nodes = 0
+        self.bits = NULL
+
+    def __dealloc__(self):
+        cdef int i
+
+        if self.bits != NULL:
+            for i in range(self.n_nodes):
+                free(self.bits[i])
+            free(self.bits)
+
+    cdef void populate(self, Node *nodes, SIZE_t n_nodes, INT32_t *n_categories):
+        cdef SIZE_t i
+        cdef INT32_t ncat
+
+        if nodes == NULL or n_categories == NULL:
+            return
+
+        self.n_nodes = n_nodes
+        safe_realloc(<void***> &self.bits, n_nodes, sizeof(void*))
+        for i in range(n_nodes):
+            self.bits[i] = NULL
+            if nodes[i].left_child != _TREE_LEAF:
+                ncat = n_categories[nodes[i].feature]
+                if ncat > 0:
+                    safe_realloc(&self.bits[i], (ncat + 31) // 32, sizeof(UINT32_t))
+                    setup_cat_cache(self.bits[i], nodes[i].split_value.cat_split, ncat)
+
+
 # =============================================================================
 # Tree
 # =============================================================================
@@ -749,7 +785,7 @@ cdef class Tree:
         return 0

     cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf,
-                          SIZE_t feature, double threshold, double impurity,
+                          SIZE_t feature, SplitValue split_value, double impurity,
                           SIZE_t n_node_samples,
                           double weighted_n_node_samples) nogil except -1:
         """Add a node to the tree.
@@ -784,7 +820,7 @@ cdef class Tree: else: # left_child and right_child will be set later node.feature = feature - node.split_value.threshold = threshold + node.split_value = split_value self.node_count += 1 @@ -830,17 +866,24 @@ cdef class Tree: # Initialize auxiliary data-structure cdef Node* node = NULL cdef SIZE_t i = 0 + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef UINT32_t** cat_caches = cache_mgr.bits + cdef UINT32_t* cache = NULL with nogil: for i in range(n_samples): node = self.nodes + cache = cat_caches[0] # While node not a leaf while node.left_child != _TREE_LEAF: # ... and node.right_child != _TREE_LEAF: - if X_ptr[X_sample_stride * i + - X_fx_stride * node.feature] <= node.split_value.threshold: + if goes_left(X_ptr[X_sample_stride * i + X_fx_stride * node.feature], + node.split_value, self.n_categories[node.feature], cache): + cache = cat_caches[node.left_child] node = &self.nodes[node.left_child] else: + cache = cat_caches[node.right_child] node = &self.nodes[node.right_child] out_ptr[i] = (node - self.nodes) # node offset @@ -881,6 +924,10 @@ cdef class Tree: cdef DTYPE_t* X_sample = NULL cdef SIZE_t i = 0 cdef INT32_t k = 0 + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef UINT32_t** cat_caches = cache_mgr.bits + cdef UINT32_t* cache = NULL # feature_to_sample as a data structure records the last seen sample # for each feature; functionally, it is an efficient way to identify @@ -895,6 +942,7 @@ cdef class Tree: for i in range(n_samples): node = self.nodes + cache = cat_caches[0] for k in range(X_indptr[i], X_indptr[i + 1]): feature_to_sample[X_indices[k]] = i @@ -909,9 +957,12 @@ cdef class Tree: else: feature_value = 0. 
- if feature_value <= node.split_value.threshold: + if goes_left(feature_value, node.split_value, + self.n_categories[node.feature], cache): + cache = cat_caches[node.left_child] node = &self.nodes[node.left_child] else: + cache = cat_caches[node.right_child] node = &self.nodes[node.right_child] out_ptr[i] = (node - self.nodes) # node offset @@ -959,10 +1010,15 @@ cdef class Tree: # Initialize auxiliary data-structure cdef Node* node = NULL cdef SIZE_t i = 0 + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef UINT32_t** cat_caches = cache_mgr.bits + cdef UINT32_t* cache = NULL with nogil: for i in range(n_samples): node = self.nodes + cache = cat_caches[0] indptr_ptr[i + 1] = indptr_ptr[i] # Add all external nodes @@ -971,10 +1027,12 @@ cdef class Tree: indices_ptr[indptr_ptr[i + 1]] = (node - self.nodes) indptr_ptr[i + 1] += 1 - if X_ptr[X_sample_stride * i + - X_fx_stride * node.feature] <= node.split_value.threshold: + if goes_left(X_ptr[X_sample_stride * i + X_fx_stride * node.feature], + node.split_value, self.n_categories[node.feature], cache): + cache = cat_caches[node.left_child] node = &self.nodes[node.left_child] else: + cache = cat_caches[node.right_child] node = &self.nodes[node.right_child] # Add the leave node @@ -1027,6 +1085,10 @@ cdef class Tree: cdef DTYPE_t* X_sample = NULL cdef SIZE_t i = 0 cdef INT32_t k = 0 + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef UINT32_t** cat_caches = cache_mgr.bits + cdef UINT32_t* cache = NULL # feature_to_sample as a data structure records the last seen sample # for each feature; functionally, it is an efficient way to identify @@ -1041,6 +1103,7 @@ cdef class Tree: for i in range(n_samples): node = self.nodes + cache = cat_caches[0] indptr_ptr[i + 1] = indptr_ptr[i] for k in range(X_indptr[i], X_indptr[i + 1]): @@ -1060,9 +1123,12 @@ cdef class Tree: else: feature_value = 0. 
-            if feature_value <= node.split_value.threshold:
+            if goes_left(feature_value, node.split_value,
+                         self.n_categories[node.feature], cache):
+                cache = cat_caches[node.left_child]
                 node = &self.nodes[node.left_child]
             else:
+                cache = cat_caches[node.right_child]
                 node = &self.nodes[node.right_child]

             # Add the leave node

diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd
index 07ea94f03e001..6b91a1e909e1d 100644
--- a/sklearn/tree/_utils.pxd
+++ b/sklearn/tree/_utils.pxd
@@ -64,6 +64,7 @@ ctypedef fused realloc_ptr:
     (PriorityHeapRecord*)
     (void**)
     (INT32_t*)
+    (UINT32_t*)

 cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems, size_t elem_bytes) nogil except *

 cdef np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size)

 cdef SIZE_t rand_int(SIZE_t low, SIZE_t high,
                      UINT32_t* random_state) nogil
@@ -82,6 +83,15 @@ cdef double rand_uniform(double low, double high,

 cdef double log(double x) nogil

+
+cdef void setup_cat_cache(UINT32_t* cachebits, UINT64_t cat_split,
+                          INT32_t n_categories) nogil
+
+
+cdef bint goes_left(DTYPE_t feature_value, SplitValue split,
+                    INT32_t n_categories, UINT32_t* cachebits) nogil
+
+
 # =============================================================================
 # Stack data structure
 # =============================================================================

diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx
index 01e93ecc7e189..712d506d8e0ea 100644
--- a/sklearn/tree/_utils.pyx
+++ b/sklearn/tree/_utils.pyx
@@ -93,6 +93,46 @@ cdef inline double log(double x) nogil:
     return ln(x) / ln(2.0)


+cdef inline void setup_cat_cache(UINT32_t *cachebits, UINT64_t cat_split,
+                                 INT32_t n_categories) nogil:
+    """Populate the bits of the category cache from a split.
+    """
+    cdef INT32_t j
+    cdef UINT32_t rng_seed, val
+
+    if n_categories > 0:
+        if cat_split & 1:
+            # RandomSplitter
+            for j in range((n_categories + 31) // 32):
+                cachebits[j] = 0
+            rng_seed = <UINT32_t>(cat_split >> 32)
+            for j in range(n_categories):
+                val = rand_int(0, 2, &rng_seed)
+                cachebits[j // 32] |= val << (j % 32)
+        else:
+            # BestSplitter
+            for j in range((n_categories + 31) // 32):
+                cachebits[j] = (cat_split >> (j * 32)) & 0xFFFFFFFF
+
+
+cdef inline bint goes_left(DTYPE_t feature_value, SplitValue split,
+                           INT32_t n_categories, UINT32_t* cachebits) nogil:
+    """Determine whether a sample goes to the left or right child node."""
+    cdef SIZE_t idx, shift
+
+    if n_categories < 1:
+        # Non-categorical feature
+        return feature_value <= split.threshold
+    else:
+        # Categorical feature, using bit cache
+        if (<SIZE_t> feature_value) < n_categories:
+            idx = (<SIZE_t> feature_value) // 32
+            shift = (<SIZE_t> feature_value) % 32
+            return (cachebits[idx] >> shift) & 1
+        else:
+            return 0
+
+
 # =============================================================================
 # Stack data structure
 # =============================================================================

From 2bd5633779da2e790221e20bf9ea9b1c6c9c19d3 Mon Sep 17 00:00:00 2001
From: Jeffrey Blackburne
Date: Sun, 8 Jan 2017 15:31:57 -0800
Subject: [PATCH 06/54] BestSplitter now calculates the best categorical split.
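The enumeration below is a pure-Python sketch of the strategy
(illustrative names, not code from this patch): for ncat_present
categories observed at the node, every distinct non-trivial binary
partition is visited exactly once by running a counter over
[1, 2**(ncat_present - 1)) and doubling it, so that bit 0 of the mask
stays 0. Category 0 therefore always goes right, complementary splits
are never revisited, and the LSB remains free to flag the
RandomSplitter encoding:

    ncat_present = 3
    for cat_idx in range(1, 1 << (ncat_present - 1)):
        cat_split = cat_idx << 1
        left = [c for c in range(ncat_present) if (cat_split >> c) & 1]
        right = [c for c in range(ncat_present) if not (cat_split >> c) & 1]
        # e.g. cat_idx=1 -> left=[1], right=[0, 2]

The real code additionally maps bit positions through cat_offs so that
only the category values actually present in the node are enumerated.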
--- sklearn/tree/_splitter.pxd | 2 + sklearn/tree/_splitter.pyx | 158 +++++++++++++++++++++++++------------ 2 files changed, 109 insertions(+), 51 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 5373e31dfa985..00e1390b797cc 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -21,6 +21,7 @@ ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight ctypedef np.npy_intp SIZE_t # Type for indices and counters ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer +ctypedef np.npy_uint64 UINT64_t # Unsigned 64 bit integer cdef struct SplitRecord: # Data to track sample split @@ -68,6 +69,7 @@ cdef class Splitter: cdef DOUBLE_t* sample_weight cdef INT32_t *n_categories # (n_features,) array giving number of # categories (<0 for non-categorical) + cdef UINT32_t* cat_cache # Cache buffer for fast categorical split evaluation # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. With this setting, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 644c5a4d95e58..63075d3b1d5c5 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -33,6 +33,8 @@ from ._utils cimport rand_int from ._utils cimport rand_uniform from ._utils cimport RAND_R_MAX from ._utils cimport safe_realloc +from ._utils cimport setup_cat_cache +from ._utils cimport goes_left cdef double INFINITY = np.inf @@ -96,6 +98,7 @@ cdef class Splitter: self.y_stride = 0 self.sample_weight = NULL self.n_categories = NULL + self.cat_cache = NULL self.max_features = max_features self.min_samples_leaf = min_samples_leaf @@ -111,6 +114,7 @@ cdef class Splitter: free(self.constant_features) free(self.feature_values) free(self.n_categories) + free(self.cat_cache) def __getstate__(self): return {} @@ -198,6 +202,12 @@ cdef class Splitter: self.n_categories[i] = (-1 if n_categories == NULL else n_categories[i]) + # If needed, allocate cache space for categorical splits + cdef INT32_t max_n_categories = max( + [self.n_categories[i] for i in range(n_features)]) + if max_n_categories > 0: + safe_realloc(&self.cat_cache, (max_n_categories + 31) // 32, sizeof(UINT32_t)) + return 0 cdef int node_reset(self, SIZE_t start, SIZE_t end, @@ -361,7 +371,6 @@ cdef class BestSplitter(BaseDenseSplitter): cdef SIZE_t f_i = n_features cdef SIZE_t f_j - cdef SIZE_t tmp cdef SIZE_t p cdef SIZE_t feature_idx_offset cdef SIZE_t feature_offset @@ -378,6 +387,10 @@ cdef class BestSplitter(BaseDenseSplitter): cdef SIZE_t n_total_constants = n_known_constants cdef DTYPE_t current_feature_value cdef SIZE_t partition_end + cdef bint is_categorical + cdef UINT64_t cat_idx, cat_split + cdef SIZE_t ncat_present + cdef INT32_t cat_offs[64] _init_split(&best, end) @@ -419,9 +432,8 @@ cdef class BestSplitter(BaseDenseSplitter): if f_j < n_known_constants: # f_j in the interval [n_drawn_constants, n_known_constants[ - tmp = features[f_j] - features[f_j] = features[n_drawn_constants] - features[n_drawn_constants] = tmp + features[f_j], features[n_drawn_constants] = ( + features[n_drawn_constants], features[f_j]) n_drawn_constants += 1 @@ -463,63 +475,113 @@ cdef class BestSplitter(BaseDenseSplitter): f_i -= 1 features[f_i], features[f_j] = features[f_j], features[f_i] + # Identify the number of categories present in this node + is_categorical = self.n_categories[current.feature] > 0 + if is_categorical: + cat_split = 0 + ncat_present = 0 + for i in 
range(start, end):
+                            # Xf[i] < 64 already verified in tree.py
+                            cat_split |= (<UINT64_t> 1) << (<SIZE_t> Xf[i])
+                        for i in range(self.n_categories[current.feature]):
+                            if (cat_split >> i) & 1:
+                                cat_offs[ncat_present] = i - ncat_present
+                                ncat_present += 1
+
                     # Evaluate all splits
                     self.criterion.reset()
                     p = start
+                    cat_idx = 0
+
+                    while True:
+                        if is_categorical:
+                            cat_idx += 1
+                            if cat_idx >= (<UINT64_t> 1) << (ncat_present - 1):
+                                break
+
+                            # Expand the bits of (2 * cat_idx) out into cat_split
+                            # We double cat_idx to avoid double-counting equivalent splits
+                            # This also ensures that cat_split & 1 == 0 as required
+                            cat_split = 0
+                            for i in range(ncat_present):
+                                cat_split |= ((cat_idx << 1) & ((<UINT64_t> 1) << i)) << cat_offs[i]
+
+                            # Partition
+                            j = start
+                            partition_end = end
+                            while j < partition_end:
+                                if (cat_split >> (<SIZE_t> Xf[j])) & 1:
+                                    j += 1
+                                else:
+                                    partition_end -= 1
+                                    Xf[j], Xf[partition_end] = Xf[partition_end], Xf[j]
+                                    samples[j], samples[partition_end] = (
+                                        samples[partition_end], samples[j])
+                            current.pos = j
+
+                            # Must reset criterion since we've reordered the samples
+                            self.criterion.reset()
+                        else:
+                            # Non-categorical feature
+                            while (p + 1 < end and
+                                   Xf[p + 1] <= Xf[p] + FEATURE_THRESHOLD):
+                                p += 1

                             # (p + 1 >= end) or (X[samples[p + 1], current.feature] >
                             #                    X[samples[p], current.feature])
                             p += 1
                             # (p >= end) or (X[samples[p], current.feature] >
                             #                X[samples[p - 1], current.feature])

+                            if p >= end:
+                                break
+
                             current.pos = p

                         # Reject if min_samples_leaf is not guaranteed
                         if (((current.pos - start) < min_samples_leaf) or
                                 ((end - current.pos) < min_samples_leaf)):
                             continue

                         self.criterion.update(current.pos)

                         # Reject if min_weight_leaf is not satisfied
                         if ((self.criterion.weighted_n_left < min_weight_leaf) or
                                 (self.criterion.weighted_n_right < min_weight_leaf)):
                             continue

                         current_proxy_improvement = self.criterion.proxy_impurity_improvement()

                         if current_proxy_improvement > best_proxy_improvement:
                             best_proxy_improvement = current_proxy_improvement
+                            if is_categorical:
+                                current.split_value.cat_split = cat_split
+                            else:
                                 current.split_value.threshold = (Xf[p - 1] + Xf[p]) / 2.0
+                                if current.split_value.threshold == Xf[p]:
+                                    current.split_value.threshold = Xf[p - 1]

-                            best = current  # copy
+                            best = current  # copy

         # Reorganize into samples[start:best.pos] + samples[best.pos:end]
         if best.pos < end:
+            setup_cat_cache(self.cat_cache, best.split_value.cat_split,
+                            self.n_categories[best.feature])
             feature_offset = X_feature_stride * best.feature
             partition_end = end
             p = start

             while p < partition_end:
-                if X[X_sample_stride * samples[p] + feature_offset] <=
best.split_value.threshold: + if goes_left(X[X_sample_stride * samples[p] + feature_offset], + best.split_value, self.n_categories[best.feature], + self.cat_cache): p += 1 else: partition_end -= 1 - tmp = samples[partition_end] - samples[partition_end] = samples[p] - samples[p] = tmp + samples[partition_end], samples[p] = ( + samples[p], samples[partition_end]) self.criterion.reset() self.criterion.update(best.pos) @@ -702,7 +764,6 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef SIZE_t f_i = n_features cdef SIZE_t f_j cdef SIZE_t p - cdef SIZE_t tmp cdef SIZE_t feature_stride # Number of features discovered to be constant during the split search cdef SIZE_t n_found_constants = 0 @@ -752,9 +813,8 @@ cdef class RandomSplitter(BaseDenseSplitter): if f_j < n_known_constants: # f_j in the interval [n_drawn_constants, n_known_constants[ - tmp = features[f_j] - features[f_j] = features[n_drawn_constants] - features[n_drawn_constants] = tmp + features[f_j], features[n_drawn_constants] = ( + features[n_drawn_constants], features[f_j]) n_drawn_constants += 1 @@ -811,9 +871,8 @@ cdef class RandomSplitter(BaseDenseSplitter): Xf[p] = Xf[partition_end] Xf[partition_end] = current_feature_value - tmp = samples[partition_end] - samples[partition_end] = samples[p] - samples[p] = tmp + samples[partition_end], samples[p] = ( + samples[p], samples[partition_end]) current.pos = partition_end @@ -851,9 +910,8 @@ cdef class RandomSplitter(BaseDenseSplitter): else: partition_end -= 1 - tmp = samples[partition_end] - samples[partition_end] = samples[p] - samples[p] = tmp + samples[partition_end], samples[p] = ( + samples[p], samples[partition_end]) self.criterion.reset() @@ -1248,7 +1306,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): cdef double best_proxy_improvement = - INFINITY cdef SIZE_t f_i = n_features - cdef SIZE_t f_j, p, tmp + cdef SIZE_t f_j, p cdef SIZE_t n_visited_features = 0 # Number of features discovered to be constant during the split search cdef SIZE_t n_found_constants = 0 @@ -1303,9 +1361,8 @@ cdef class BestSparseSplitter(BaseSparseSplitter): if f_j < n_known_constants: # f_j in the interval [n_drawn_constants, n_known_constants[ - tmp = features[f_j] - features[f_j] = features[n_drawn_constants] - features[n_drawn_constants] = tmp + features[f_j], features[n_drawn_constants] = ( + features[n_drawn_constants], features[f_j]) n_drawn_constants += 1 @@ -1480,7 +1537,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): cdef DTYPE_t current_feature_value cdef SIZE_t f_i = n_features - cdef SIZE_t f_j, p, tmp + cdef SIZE_t f_j, p cdef SIZE_t n_visited_features = 0 # Number of features discovered to be constant during the split search cdef SIZE_t n_found_constants = 0 @@ -1536,9 +1593,8 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): if f_j < n_known_constants: # f_j in the interval [n_drawn_constants, n_known_constants[ - tmp = features[f_j] - features[f_j] = features[n_drawn_constants] - features[n_drawn_constants] = tmp + features[f_j], features[n_drawn_constants] = ( + features[n_drawn_constants], features[f_j]) n_drawn_constants += 1 From e3f0a995604002d7230380b6bb30482ba0d3b23b Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Sat, 29 Oct 2016 00:39:11 -0700 Subject: [PATCH 07/54] Added categorical split code to RandomSplitter.node_split --- sklearn/tree/_splitter.pyx | 65 +++++++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 
63075d3b1d5c5..e66aefbe8694d 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -763,7 +763,7 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef SIZE_t f_i = n_features cdef SIZE_t f_j - cdef SIZE_t p + cdef SIZE_t p, q cdef SIZE_t feature_stride # Number of features discovered to be constant during the split search cdef SIZE_t n_found_constants = 0 @@ -777,6 +777,8 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef DTYPE_t max_feature_value cdef DTYPE_t current_feature_value cdef SIZE_t partition_end + cdef bint is_categorical + cdef UINT64_t split_seed _init_split(&best, end) @@ -851,30 +853,45 @@ cdef class RandomSplitter(BaseDenseSplitter): f_i -= 1 features[f_i], features[f_j] = features[f_j], features[f_i] - # Draw a random threshold - current.split_value.threshold = rand_uniform( - min_feature_value, max_feature_value, random_state) - - if current.split_value.threshold == max_feature_value: - current.split_value.threshold = min_feature_value - - # Partition - partition_end = end - p = start - while p < partition_end: - current_feature_value = Xf[p] - if current_feature_value <= current.split_value.threshold: - p += 1 + # Repeat split & partition if split is trivial, up to 60 times + # (Can only happen with categorical features) + for q in range(60): + # Construct a random split + is_categorical = self.n_categories[current.feature] > 0 + if is_categorical: + split_seed = rand_int(0, RAND_R_MAX + 1, + random_state) + current.split_value.cat_split = (split_seed << 32) | 1 else: - partition_end -= 1 + current.split_value.threshold = rand_uniform( + min_feature_value, max_feature_value, random_state) + if current.split_value.threshold == max_feature_value: + current.split_value.threshold = min_feature_value + + # Partition + setup_cat_cache(self.cat_cache, current.split_value.cat_split, + self.n_categories[current.feature]) + partition_end = end + p = start + while p < partition_end: + current_feature_value = Xf[p] + if goes_left(current_feature_value, current.split_value, + self.n_categories[current.feature], self.cat_cache): + p += 1 + else: + partition_end -= 1 + + Xf[p] = Xf[partition_end] + Xf[partition_end] = current_feature_value - Xf[p] = Xf[partition_end] - Xf[partition_end] = current_feature_value + samples[partition_end], samples[p] = ( + samples[p], samples[partition_end]) - samples[partition_end], samples[p] = ( - samples[p], samples[partition_end]) + current.pos = partition_end - current.pos = partition_end + # Break early if the split is non-trivial + if current.pos != start and current.pos != end: + break # Reject if min_samples_leaf is not guaranteed if (((current.pos - start) < min_samples_leaf) or @@ -899,12 +916,16 @@ cdef class RandomSplitter(BaseDenseSplitter): # Reorganize into samples[start:best.pos] + samples[best.pos:end] feature_stride = X_feature_stride * best.feature if best.pos < end: + setup_cat_cache(self.cat_cache, best.split_value.cat_split, + self.n_categories[best.feature]) if current.feature != best.feature: partition_end = end p = start while p < partition_end: - if X[X_sample_stride * samples[p] + feature_stride] <= best.split_value.threshold: + if goes_left(X[X_sample_stride * samples[p] + feature_stride], + best.split_value, self.n_categories[best.feature], + self.cat_cache): p += 1 else: From 35dad263f0c8d8b6b9a1abb32a799422ac801fb2 Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Sun, 8 Jan 2017 15:46:21 -0800 Subject: [PATCH 08/54] Added an implementation of the Breiman sorting shortcut for finding the best 
categorical split. --- sklearn/tree/_splitter.pxd | 2 + sklearn/tree/_splitter.pyx | 85 +++++++++++++++++++++++++++++++++----- sklearn/tree/tree.py | 9 +++- 3 files changed, 84 insertions(+), 12 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 00e1390b797cc..035a5dd48ee9d 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -63,6 +63,8 @@ cdef class Splitter: cdef bint presort # Whether to use presorting, only # allowed on dense data + cdef bint breiman_shortcut # Whether decision trees are allowed to use the + # Breiman shortcut for categorical features cdef DOUBLE_t* y cdef SIZE_t y_stride diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index e66aefbe8694d..222865b10e3a4 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -62,7 +62,7 @@ cdef class Splitter: def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint breiman_shortcut): """ Parameters ---------- @@ -105,6 +105,7 @@ cdef class Splitter: self.min_weight_leaf = min_weight_leaf self.random_state = random_state self.presort = presort + self.breiman_shortcut = breiman_shortcut def __dealloc__(self): """Destructor.""" @@ -277,7 +278,7 @@ cdef class BaseDenseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint breiman_shortcut): self.X = NULL self.X_sample_stride = 0 @@ -337,6 +338,49 @@ cdef class BestSplitter(BaseDenseSplitter): self.random_state, self.presort), self.__getstate__()) + + cdef void _breiman_sort_categories(self, SIZE_t start, SIZE_t end, INT32_t ncat, + SIZE_t ncat_present, const INT32_t *cat_offs, + SIZE_t *sorted_cat) nogil: + """The Breiman shortcut for finding the best split involves a + preprocessing step wherein we sort the categories by + increasing (weighted) mean of the outcome y (whether 0/1 + binary for classification or quantitative for + regression). This function implements this preprocessing step + and produces a sorted list of category values. 
+ """ + cdef: + SIZE_t *samples = self.samples + DTYPE_t *Xf = self.feature_values + DOUBLE_t *y = self.y + SIZE_t y_stride = self.y_stride + DOUBLE_t *sample_weight = self.sample_weight + DOUBLE_t w + SIZE_t cat, localcat + SIZE_t q, partition_end + DTYPE_t sort_value[64] + DTYPE_t sort_den[64] + + for cat in range(ncat): + sort_value[cat] = 0 + sort_den[cat] = 0 + + for q in range(start, end): + cat = Xf[q] + w = sample_weight[samples[q]] if sample_weight else 1.0 + sort_value[cat] += w * (y[y_stride * samples[q]]) + sort_den[cat] += w + + for localcat in range(ncat_present): + cat = localcat + cat_offs[localcat] + if sort_den[cat] == 0: # Avoid dividing zero by zero + sort_den[cat] = 1 + sort_value[localcat] = sort_value[cat] / sort_den[cat] + sorted_cat[localcat] = cat + + sort(&sort_value[0], sorted_cat, ncat_present) + + cdef int node_split(self, double impurity, SplitRecord* split, SIZE_t* n_constant_features) nogil except -1: """Find the best split on node samples[start:end] @@ -391,6 +435,8 @@ cdef class BestSplitter(BaseDenseSplitter): cdef UINT64_t cat_idx, cat_split cdef SIZE_t ncat_present cdef INT32_t cat_offs[64] + cdef bint breiman_shortcut = self.breiman_shortcut + cdef SIZE_t sorted_cat[64] _init_split(&best, end) @@ -487,6 +533,12 @@ cdef class BestSplitter(BaseDenseSplitter): if (cat_split >> i) & 1: cat_offs[ncat_present] = i - ncat_present ncat_present += 1 + if ncat_present <= 3: + breiman_shortcut = False # No benefit for small N + if breiman_shortcut: + self._breiman_sort_categories( + start, end, self.n_categories[current.feature], + ncat_present, cat_offs, &sorted_cat[0]) # Evaluate all splits self.criterion.reset() @@ -496,15 +548,26 @@ cdef class BestSplitter(BaseDenseSplitter): while True: if is_categorical: cat_idx += 1 - if cat_idx >= ( 1) << (ncat_present - 1): - break + if breiman_shortcut: + if cat_idx >= ncat_present: + break + + cat_split = 0 + for i in range(cat_idx): + cat_split |= ( 1) << sorted_cat[i] + if cat_split & 1: + cat_split = (~cat_split) & ( + (~( 0)) >> (64 - self.n_categories[current.feature])) + else: + if cat_idx >= ( 1) << (ncat_present - 1): + break - # Expand the bits of (2 * cat_idx) out into cat_split - # We double cat_idx to avoid double-counting equivalent splits - # This also ensures that cat_split & 1 == 0 as required - cat_split = 0 - for i in range(ncat_present): - cat_split |= ((cat_idx << 1) & (( 1) << i)) << cat_offs[i] + # Expand the bits of (2 * cat_idx) out into cat_split + # We double cat_idx to avoid double-counting equivalent splits + # This also ensures that cat_split & 1 == 0 as required + cat_split = 0 + for i in range(ncat_present): + cat_split |= ((cat_idx << 1) & (( 1) << i)) << cat_offs[i] # Partition j = start @@ -970,7 +1033,7 @@ cdef class BaseSparseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint breiman_shortcut): # Parent __cinit__ is automatically called self.X_data = NULL diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index 7a1fdbb5585a7..f0ad84cdfd41f 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -315,6 +315,12 @@ def fit(self, X, y, sample_weight=None, check_input=True, else: criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples) + if is_classification: + breiman_shortcut = (self.n_classes_.tolist() == [2] and + (isinstance(criterion, _criterion.Gini) or + isinstance(criterion, _criterion.Entropy))) 
+ else: + breiman_shortcut = isinstance(criterion, _criterion.MSE) SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS @@ -325,7 +331,8 @@ def fit(self, X, y, sample_weight=None, check_input=True, min_samples_leaf, min_weight_leaf, random_state, - self.presort) + self.presort, + breiman_shortcut) self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_, n_categories) From 2312ce03f75f215ea990c66609d02a29f2c21aca Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Fri, 4 Nov 2016 09:16:18 -0700 Subject: [PATCH 09/54] Added categorical constructor parameter and error checking to BaseDecisionTree. --- sklearn/tree/tree.py | 123 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 111 insertions(+), 12 deletions(-) diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index f0ad84cdfd41f..1674d35abff9e 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -21,6 +21,7 @@ from abc import ABCMeta from abc import abstractmethod from math import ceil +import warnings import numpy as np from scipy.sparse import issparse @@ -91,7 +92,8 @@ def __init__(self, random_state, min_impurity_split, class_weight=None, - presort=False): + presort=False, + categorical="none"): self.criterion = criterion self.splitter = splitter self.max_depth = max_depth @@ -104,6 +106,16 @@ def __init__(self, self.min_impurity_split = min_impurity_split self.class_weight = class_weight self.presort = presort + self.categorical = categorical + + # Input validation for parameter categorical + if isinstance(self.categorical, str): + if categorical not in ('all', 'none'): + raise ValueError("Invalid value for categorical: {}. Allowed" + " strings are 'all' or 'none'" + "".format(categorical)) + elif len(np.shape(categorical)) != 1: + raise ValueError("Invalid shape for categorical") def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): @@ -264,6 +276,54 @@ def fit(self, X, y, sample_weight=None, check_input=True, else: sample_weight = expanded_class_weight + # Validate categorical features + if isinstance(self.categorical, str): + if self.categorical == 'none': + categorical = np.array([], dtype=np.int) + elif self.categorical == 'all': + categorical = np.arange(self.n_features_) + else: + # Should have been caught in the constructor, but just in case + raise ValueError("Invalid value for categorical: {}. 
Allowed"
+                                 " strings are 'all' or 'none'"
+                                 "".format(self.categorical))
+        else:
+            categorical = np.atleast_1d(self.categorical).flatten()
+            if categorical.dtype == np.bool:
+                if categorical.size != self.n_features_:
+                    raise ValueError("Shape of boolean parameter categorical must"
+                                     " be (n_features,)")
+                categorical = np.nonzero(categorical)[0]
+            if (np.size(categorical) > self.n_features_ or
+                    (categorical.size > 0 and
+                     (categorical.min() < 0 or
+                      categorical.max() >= self.n_features_))):
+                raise ValueError("Invalid shape or invalid feature index for"
+                                 " parameter categorical")
+        if issparse(X):
+            if categorical.size > 0:
+                raise NotImplementedError("Categorical features not supported"
+                                          " with sparse inputs")
+        else:
+            if np.any(X[:, categorical].astype(np.int) < 0):
+                raise ValueError("Invalid training data: categorical values"
+                                 " must be non-negative.")
+
+        # Calculate n_categories and verify they are all at least 1% populated
+        n_categories = np.array([np.int(X[:, i].max()) + 1 if i in categorical
+                                 else -1 for i in range(self.n_features_)],
+                                dtype=np.int32)
+        n_cat_present = np.array([np.unique(X[:, i].astype(np.int)).size
+                                  if i in categorical else -1
+                                  for i in range(self.n_features_)],
+                                 dtype=np.int32)
+        if np.any((n_cat_present < 0.01 * n_categories)[categorical]):
+            warnings.warn("At least one categorical feature has less than 1%"
+                          " of its categories present in the sample. Runtime"
+                          " and memory usage will be much smaller if you"
+                          " represent the categories as sequential integers.",
+                          UserWarning)
+
         # Set min_weight_leaf from min_weight_fraction_leaf
         if sample_weight is None:
             min_weight_leaf = (self.min_weight_fraction_leaf *
@@ -303,9 +363,6 @@ def fit(self, X, y, sample_weight=None, check_input=True,
                              ".shape = {})".format(X.shape,
                                                    X_idx_sorted.shape))

-        # Set n_categories (hard-code -1 for now)
-        n_categories = np.array([-1] * self.n_features_, dtype=np.int32)
-
         # Build tree
         criterion = self.criterion
         if not isinstance(criterion, Criterion):
@@ -334,6 +391,12 @@ def fit(self, X, y, sample_weight=None, check_input=True,
                                                 self.presort,
                                                 breiman_shortcut)

+        if (not isinstance(splitter, _splitter.RandomSplitter) and
+                np.max(n_categories) > 64):
+            raise ValueError("Categorical features with greater than 64"
+                             " categories not supported with DecisionTree;"
+                             " try ExtraTree.")
+
         self.tree_ = Tree(self.n_features_, self.n_classes_,
                           self.n_outputs_, n_categories)
@@ -372,6 +435,9 @@ def _validate_X_predict(self, X, check_input):
                     X.indptr.dtype != np.intc):
                 raise ValueError("No support for np.int64 index based "
                                  "sparse matrices")
+        if issparse(X) and np.any(self.tree_.n_categories > 0):
+            raise NotImplementedError("Categorical features not supported"
+                                      " with sparse inputs")

         n_features = X.shape[1]
         if self.n_features_ != n_features:
@@ -612,6 +678,19 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin):
         When using either a smaller dataset or a restricted depth, this may
         speed up the training.

+    categorical : array-like or str
+        Array of feature indices, boolean array of length n_features,
+        ``'all'`` or ``'none'``. Indicates which features should be
+        considered as categorical rather than ordinal. For decision trees,
+        the maximum number of categories is 64. In practice, the limit will
+        often be lower because the process of searching for the best possible
+        split grows exponentially with the number of categories. However, a
+        shortcut due to Breiman (1984) is used when fitting data with binary
+        labels using the ``Gini`` or ``Entropy`` criteria.
In this case, + the runtime is linear in the number of categories. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + Attributes ---------- classes_ : array of shape = [n_classes] or a list of such arrays @@ -684,7 +763,8 @@ def __init__(self, max_leaf_nodes=None, min_impurity_split=1e-7, class_weight=None, - presort=False): + presort=False, + categorical="none"): super(DecisionTreeClassifier, self).__init__( criterion=criterion, splitter=splitter, @@ -697,7 +777,8 @@ def __init__(self, class_weight=class_weight, random_state=random_state, min_impurity_split=min_impurity_split, - presort=presort) + presort=presort, + categorical=categorical) def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): @@ -913,6 +994,18 @@ class DecisionTreeRegressor(BaseDecisionTree, RegressorMixin): When using either a smaller dataset or a restricted depth, this may speed up the training. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data using the + ``MSE`` criterion. In this case, the runtime is linear in the number + of categories. Extra-random trees have an upper limit of :math:`2^{31}` + categories, and runtimes linear in the number of categories. + Attributes ---------- feature_importances_ : array of shape = [n_features] @@ -976,7 +1069,8 @@ def __init__(self, random_state=None, max_leaf_nodes=None, min_impurity_split=1e-7, - presort=False): + presort=False, + categorical="none"): super(DecisionTreeRegressor, self).__init__( criterion=criterion, splitter=splitter, @@ -988,7 +1082,8 @@ def __init__(self, max_leaf_nodes=max_leaf_nodes, random_state=random_state, min_impurity_split=min_impurity_split, - presort=presort) + presort=presort, + categorical=categorical) def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): @@ -1069,7 +1164,8 @@ def __init__(self, random_state=None, max_leaf_nodes=None, min_impurity_split=1e-7, - class_weight=None): + class_weight=None, + categorical="none"): super(ExtraTreeClassifier, self).__init__( criterion=criterion, splitter=splitter, @@ -1081,7 +1177,8 @@ def __init__(self, max_leaf_nodes=max_leaf_nodes, class_weight=class_weight, min_impurity_split=min_impurity_split, - random_state=random_state) + random_state=random_state, + categorical=categorical) class ExtraTreeRegressor(DecisionTreeRegressor): @@ -1118,7 +1215,8 @@ def __init__(self, max_features="auto", random_state=None, min_impurity_split=1e-7, - max_leaf_nodes=None): + max_leaf_nodes=None, + categorical="none"): super(ExtraTreeRegressor, self).__init__( criterion=criterion, splitter=splitter, @@ -1129,4 +1227,5 @@ def __init__(self, max_features=max_features, max_leaf_nodes=max_leaf_nodes, min_impurity_split=min_impurity_split, - random_state=random_state) + random_state=random_state, + categorical=categorical) From e0068f8622f3223b85fd00f742bb147ea3d32bee Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Tue, 22 Nov 2016 17:40:49 -0800 Subject: [PATCH 10/54] Added the categorical keyword to forest constructors. 
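
For reference, the ``categorical`` keyword threaded through the tree and
forest constructors in these patches is intended to be used as follows. This
is a usage sketch against this branch only; the data, feature indices, and
``n_estimators`` value are illustrative, and the keyword does not exist in
released scikit-learn:

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.ensemble import RandomForestClassifier

    # Column 0 is ordinal; column 1 holds category codes 0..3. Categorical
    # values must be non-negative integers, per the validation added above.
    X = np.array([[1.2, 0], [0.7, 2], [3.1, 1], [2.2, 3], [0.5, 2], [1.9, 0]])
    y = np.array([0, 1, 0, 1, 1, 0])

    # Equivalent ways to mark column 1 as categorical:
    clf = DecisionTreeClassifier(categorical=[1]).fit(X, y)
    clf = DecisionTreeClassifier(categorical=[False, True]).fit(X, y)
    rf = RandomForestClassifier(n_estimators=10, categorical=[1]).fit(X, y)
    # categorical='all' and the default categorical='none' are also accepted.
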
--- sklearn/ensemble/forest.py | 63 +++++++++++++++++++++++++++++++++++--- 1 file changed, 58 insertions(+), 5 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 1c160be7870bc..176b623263a6d 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -813,6 +813,17 @@ class RandomForestClassifier(ForestClassifier): .. versionadded:: 0.18 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. + bootstrap : boolean, optional (default=True) Whether bootstrap samples are used when building trees. @@ -908,6 +919,7 @@ def __init__(self, max_features="auto", max_leaf_nodes=None, min_impurity_split=1e-7, + categorical="none", bootstrap=True, oob_score=False, n_jobs=1, @@ -921,7 +933,7 @@ def __init__(self, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -938,6 +950,7 @@ def __init__(self, self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split + self.categorical = categorical class RandomForestRegressor(ForestRegressor): @@ -1025,6 +1038,17 @@ class RandomForestRegressor(ForestRegressor): .. versionadded:: 0.18 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. + bootstrap : boolean, optional (default=True) Whether bootstrap samples are used when building trees. @@ -1089,6 +1113,7 @@ def __init__(self, max_features="auto", max_leaf_nodes=None, min_impurity_split=1e-7, + categorical="none", bootstrap=True, oob_score=False, n_jobs=1, @@ -1101,7 +1126,7 @@ def __init__(self, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1117,6 +1142,7 @@ def __init__(self, self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split + self.categorical = categorical class ExtraTreesClassifier(ForestClassifier): @@ -1197,6 +1223,13 @@ class ExtraTreesClassifier(ForestClassifier): .. 
versionadded:: 0.18 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + bootstrap : boolean, optional (default=False) Whether bootstrap samples are used when building trees. @@ -1293,6 +1326,7 @@ def __init__(self, max_features="auto", max_leaf_nodes=None, min_impurity_split=1e-7, + categorical="none", bootstrap=False, oob_score=False, n_jobs=1, @@ -1306,7 +1340,7 @@ def __init__(self, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1323,6 +1357,7 @@ def __init__(self, self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split + self.categorical = categorical class ExtraTreesRegressor(ForestRegressor): @@ -1408,6 +1443,13 @@ class ExtraTreesRegressor(ForestRegressor): .. versionadded:: 0.18 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + bootstrap : boolean, optional (default=False) Whether bootstrap samples are used when building trees. @@ -1473,6 +1515,7 @@ def __init__(self, max_features="auto", max_leaf_nodes=None, min_impurity_split=1e-7, + categorical="none", bootstrap=False, oob_score=False, n_jobs=1, @@ -1485,7 +1528,7 @@ def __init__(self, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1501,6 +1544,7 @@ def __init__(self, self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split + self.categorical = categorical class RandomTreesEmbedding(BaseForest): @@ -1566,6 +1610,13 @@ class RandomTreesEmbedding(BaseForest): .. versionadded:: 0.18 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + sparse_output : bool, optional (default=True) Whether or not to return a sparse CSR matrix, as default behavior, or to return a dense array compatible with dense pipeline operators. 
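
The :math:`2^{31}` ceiling quoted in these docstrings exists because
extra-random trees never enumerate category subsets: as introduced in the
RandomSplitter patch earlier in this series, they draw a random 31-bit seed,
store it in the upper half of ``cat_split``, and set the low bit to 1;
``setup_cat_cache`` later expands that seed into one left/right draw per
category, so the category count is bounded by the 32-bit category counter
rather than by the 64-bit membership mask. A rough pure-Python model of the
encoding (``random.Random`` is a stand-in for the Cython ``rand_r``-style
helpers, so the exact draws differ from the real splitter):

    import random

    def draw_random_cat_split(rng):
        # Upper 32 bits hold the seed; a set low bit flags this encoding
        seed = rng.randrange(2 ** 31)
        return (seed << 32) | 1

    def expand_cat_split(cat_split, n_categories):
        # Rebuild the per-category go-left bits from the stored seed, the
        # way the category cache is populated before partitioning/prediction
        assert cat_split & 1, "low bit marks a seed-encoded categorical split"
        coin = random.Random(cat_split >> 32)
        return [bool(coin.getrandbits(1)) for _ in range(n_categories)]

    split = draw_random_cat_split(random.Random(0))
    print(expand_cat_split(split, 8))  # one go-left decision per category
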
@@ -1611,6 +1662,7 @@ def __init__(self, min_weight_fraction_leaf=0., max_leaf_nodes=None, min_impurity_split=1e-7, + categorical="none", sparse_output=True, n_jobs=1, random_state=None, @@ -1622,7 +1674,7 @@ def __init__(self, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=False, oob_score=False, n_jobs=n_jobs, @@ -1639,6 +1691,7 @@ def __init__(self, self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split self.sparse_output = sparse_output + self.categorical = categorical def _set_oob_score(self, X, y): raise NotImplementedError("OOB score not supported by tree embedding") From 1b7732603e29ae594da6b2ae74facc24cfec17e3 Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Sun, 4 Dec 2016 18:33:09 -0800 Subject: [PATCH 11/54] Added some unit tests. --- sklearn/tree/tests/test_tree.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index ff662e9af414a..774391974757f 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1634,3 +1634,34 @@ def _pickle_copy(obj): assert_equal(typename, typename_) assert_equal(n_outputs, n_outputs_) assert_equal(n_samples, n_samples_) + + +def check_invalid_categorical(name): + Tree = ALL_TREES[name] + raise_on_construction = ['invalid string', [[0]]] + raise_on_fit = [[False, False, False], [1, 2], [-3], [0, 0, 1]] + for catval in raise_on_construction: + assert_raises(ValueError, Tree, categorical=catval) + for catval in raise_on_fit: + assert_raises(ValueError, Tree(categorical=catval).fit, X, y) + + +def test_invalid_categorical(): + for name in ALL_TREES: + yield check_invalid_categorical, name + + +def check_no_sparse_with_categorical(name): + X, y, X_sparse = [DATASETS['clf_small'][z] + for z in ['X', 'y', 'X_sparse']] + Tree = ALL_TREES[name] + assert_raises(NotImplementedError, Tree(categorical=[6, 10]).fit, + X_sparse, y) + assert_raises(NotImplementedError, + Tree(categorical=[6, 10]).fit(X, y).predict, X_sparse) + + +def test_no_sparse_with_categorical(): + # Currently we do not support sparse categorical features + for name in SPARSE_TREES: + yield check_no_sparse_with_categorical, name From 149b7bfb2766549ab4002068cb60886be86124a0 Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Tue, 20 Dec 2016 20:59:44 -0800 Subject: [PATCH 12/54] Refactored _partial_dependence_tree a little. 
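
The refactor below keeps the algorithm intact but replaces the raw ``Node**``
stack (previously carved out of a NumPy ``intp`` array) with a stack of node
indices relative to the root. In rough pure-Python terms, the traversal being
rewritten looks like this; the dict-based nodes are illustrative stand-ins
for the real ``Node`` struct, not the actual API:

    def partial_dependence_walk(nodes, values, target_features, x_target):
        # Splits on target features follow x_target; splits on all other
        # features descend into both children, weighted by sample fractions.
        total = 0.0
        stack = [(0, 1.0)]  # (node index relative to root, weight)
        while stack:
            idx, weight = stack.pop()
            node = nodes[idx]
            if node['left'] == -1:  # leaf
                total += weight * values[idx]
                continue
            if node['feature'] in target_features:
                go_left = x_target[node['feature']] <= node['threshold']
                stack.append((node['left'] if go_left else node['right'],
                              weight))
            else:
                frac = nodes[node['left']]['n_samples'] / node['n_samples']
                stack.append((node['left'], weight * frac))
                stack.append((node['right'], weight * (1.0 - frac)))
        return total

    # Two-leaf tree splitting on feature 0, which is not in the target set,
    # so both leaves contribute: 0.6 * (-1.0) + 0.4 * 2.0, i.e. about 0.2
    nodes = [
        {'left': 1, 'right': 2, 'feature': 0, 'threshold': 0.5,
         'n_samples': 10},
        {'left': -1, 'right': -1, 'feature': -1, 'threshold': 0.0,
         'n_samples': 6},
        {'left': -1, 'right': -1, 'feature': -1, 'threshold': 0.0,
         'n_samples': 4},
    ]
    print(partial_dependence_walk(nodes, [0.0, -1.0, 2.0], {1}, {1: 0.0}))
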
--- sklearn/ensemble/_gradient_boosting.pyx | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx index 43bced4b46742..dd0f7b985251f 100644 --- a/sklearn/ensemble/_gradient_boosting.pyx +++ b/sklearn/ensemble/_gradient_boosting.pyx @@ -290,27 +290,25 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, cdef SIZE_t node_count = tree.node_count cdef SIZE_t stack_capacity = node_count * 2 - cdef Node **node_stack cdef double[::1] weight_stack = np_ones((stack_capacity,), dtype=np_float64) cdef SIZE_t stack_size = 1 cdef double left_sample_frac cdef double current_weight cdef double total_weight = 0.0 cdef Node *current_node - underlying_stack = np_zeros((stack_capacity,), dtype=np.intp) - node_stack = ( underlying_stack).data + cdef SIZE_t[::1] node_stack = np_zeros((stack_capacity,), dtype=np.intp) for i in range(X.shape[0]): # init stacks for new example stack_size = 1 - node_stack[0] = root_node + node_stack[0] = 0 weight_stack[0] = 1.0 total_weight = 0.0 while stack_size > 0: # get top node on stack stack_size -= 1 - current_node = node_stack[stack_size] + current_node = root_node + node_stack[stack_size] if current_node.left_child == TREE_LEAF: out[i] += weight_stack[stack_size] * value[current_node - root_node] * \ @@ -324,19 +322,17 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, # push left or right child on stack if X[i, feature_index] <= current_node.split_value.threshold: # left - node_stack[stack_size] = (root_node + - current_node.left_child) + node_stack[stack_size] = current_node.left_child else: # right - node_stack[stack_size] = (root_node + - current_node.right_child) + node_stack[stack_size] = current_node.right_child stack_size += 1 else: # split feature in complement set # push both children onto stack # push left child - node_stack[stack_size] = root_node + current_node.left_child + node_stack[stack_size] = current_node.left_child current_weight = weight_stack[stack_size] left_sample_frac = root_node[current_node.left_child].n_node_samples / \ current_node.n_node_samples @@ -351,7 +347,7 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, stack_size +=1 # push right child - node_stack[stack_size] = root_node + current_node.right_child + node_stack[stack_size] = current_node.right_child weight_stack[stack_size] = current_weight * \ (1.0 - left_sample_frac) stack_size +=1 From b65da1c09e9b3b1de6fd177018407e309ae36678 Mon Sep 17 00:00:00 2001 From: Jeffrey Blackburne Date: Tue, 20 Dec 2016 22:58:10 -0800 Subject: [PATCH 13/54] Added categorical support to gradient boosting. 
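
The core of the change below is that every prediction-time comparison of the
form ``X[...] <= node.split_value.threshold`` becomes a call to ``goes_left``,
which dispatches on whether the node's feature is categorical. A simplified
Python rendering of that test (the real helper additionally reads the bit out
of the per-node category cache built by ``CategoryCacheMgr``, which is what
makes the seed-encoded random splits work):

    def goes_left(feature_value, split_value, n_categories):
        if n_categories <= 0:
            # Ordinal feature: split_value acts as a threshold.
            return feature_value <= split_value
        # Categorical feature: split_value acts as a 64-bit membership mask
        # whose bit c means "category c goes to the left child".
        return bool((split_value >> int(feature_value)) & 1)

    assert goes_left(1.2, 2.5, -1)         # ordinal: 1.2 <= 2.5
    assert goes_left(3.0, 0b01000, 5)      # categorical: bit 3 set -> left
    assert not goes_left(2.0, 0b01000, 5)  # bit 2 clear -> right
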
--- sklearn/ensemble/_gradient_boosting.pyx | 37 +++++++++++++++++++++++-- sklearn/ensemble/gradient_boosting.py | 36 ++++++++++++++++++++---- 2 files changed, 65 insertions(+), 8 deletions(-) diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx index dd0f7b985251f..d7b9ed7e64445 100644 --- a/sklearn/ensemble/_gradient_boosting.pyx +++ b/sklearn/ensemble/_gradient_boosting.pyx @@ -20,10 +20,13 @@ from scipy.sparse import csr_matrix from sklearn.tree._tree cimport Node from sklearn.tree._tree cimport Tree +from sklearn.tree._tree cimport CategoryCacheMgr from sklearn.tree._tree cimport DTYPE_t from sklearn.tree._tree cimport SIZE_t from sklearn.tree._tree cimport INT32_t +from sklearn.tree._tree cimport UINT32_t from sklearn.tree._utils cimport safe_realloc +from sklearn.tree._utils cimport goes_left ctypedef np.int32_t int32 ctypedef np.float64_t float64 @@ -48,6 +51,8 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X, Py_ssize_t K, Py_ssize_t n_samples, Py_ssize_t n_features, + INT32_t* n_categories, + UINT32_t** cachebits, float64 *out): """Predicts output for regression tree and stores it in ``out[i, k]``. @@ -82,6 +87,12 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X, ``n_samples == X.shape[0]``. n_features : int The number of features; ``n_samples == X.shape[1]``. + n_categories : INT32_t pointer + Array of length n_features containing the number of categories + (for categorical features) or -1 (for non-categorical features) + cachebits : UINT32_t pointer pointer + Array of length node_count containing category cache buffers + for categorical features out : np.float64_t pointer The pointer to the data array where the predictions are stored. ``out`` is assumed to be a two-dimensional array of @@ -89,13 +100,19 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X, """ cdef Py_ssize_t i cdef Node *node + cdef UINT32_t* node_cache + for i in range(n_samples): node = root_node + node_cache = cachebits[0] # While node not a leaf while node.left_child != TREE_LEAF: - if X[i * n_features + node.feature] <= node.split_value.threshold: + if goes_left(X[i * n_features + node.feature], node.split_value, + n_categories[node.feature], node_cache): + node_cache = cachebits[node.left_child] node = root_node + node.left_child else: + node_cache = cachebits[node.right_child] node = root_node + node.right_child out[i * K + k] += scale * value[node - root_node] @@ -213,6 +230,10 @@ def predict_stages(np.ndarray[object, ndim=2] estimators, for k in range(K): tree = estimators[i, k].tree_ + # Make category cache buffers for this tree's nodes + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(tree.nodes, tree.node_count, tree.n_categories) + # avoid buffer validation by casting to ndarray # and get data pointer # need brackets because of casting operator priority @@ -220,6 +241,7 @@ def predict_stages(np.ndarray[object, ndim=2] estimators, ( X).data, tree.nodes, tree.value, scale, k, K, X.shape[0], X.shape[1], + tree.n_categories, cache_mgr.bits, ( out).data) ## out += scale * tree.predict(X).reshape((X.shape[0], 1)) @@ -297,11 +319,19 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, cdef double total_weight = 0.0 cdef Node *current_node cdef SIZE_t[::1] node_stack = np_zeros((stack_capacity,), dtype=np.intp) + cdef UINT32_t** cachebits + cdef UINT32_t* node_cache + + # Make category cache buffers for this tree's nodes + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(root_node, node_count, 
tree.n_categories) + cachebits = cache_mgr.bits for i in range(X.shape[0]): # init stacks for new example stack_size = 1 node_stack[0] = 0 + node_cache = cachebits[0] weight_stack[0] = 1.0 total_weight = 0.0 @@ -309,6 +339,7 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, # get top node on stack stack_size -= 1 current_node = root_node + node_stack[stack_size] + node_cache = cachebits[node_stack[stack_size]] if current_node.left_child == TREE_LEAF: out[i] += weight_stack[stack_size] * value[current_node - root_node] * \ @@ -320,7 +351,9 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, if feature_index != -1: # split feature in target set # push left or right child on stack - if X[i, feature_index] <= current_node.split_value.threshold: + if goes_left(X[i, feature_index], current_node.split_value, + tree.n_categories[current_node.feature], + node_cache): # left node_stack[stack_size] = current_node.left_child else: diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index 49a3fd1a9e348..e8c05f7405ead 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -722,7 +722,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion, min_samples_split, min_samples_leaf, min_weight_fraction_leaf, max_depth, min_impurity_split, init, subsample, max_features, random_state, alpha=0.9, verbose=0, max_leaf_nodes=None, - warm_start=False, presort='auto'): + warm_start=False, presort='auto', categorical='none'): self.n_estimators = n_estimators self.learning_rate = learning_rate @@ -742,6 +742,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion, self.max_leaf_nodes = max_leaf_nodes self.warm_start = warm_start self.presort = presort + self.categorical = categorical def _fit_stage(self, i, X, y, y_pred, sample_weight, sample_mask, random_state, X_idx_sorted, X_csc=None, X_csr=None): @@ -770,7 +771,8 @@ def _fit_stage(self, i, X, y, y_pred, sample_weight, sample_mask, max_features=self.max_features, max_leaf_nodes=self.max_leaf_nodes, random_state=random_state, - presort=self.presort) + presort=self.presort, + categorical=self.categorical) if self.subsample < 1.0: # no inplace multiplication! @@ -1357,6 +1359,17 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): .. versionadded:: 0.17 *presort* parameter. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. 
+ Attributes ---------- feature_importances_ : array, shape = [n_features] @@ -1409,7 +1422,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100, max_depth=3, min_impurity_split=1e-7, init=None, random_state=None, max_features=None, verbose=0, max_leaf_nodes=None, warm_start=False, - presort='auto'): + presort='auto', categorical='none'): super(GradientBoostingClassifier, self).__init__( loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, @@ -1422,7 +1435,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100, max_leaf_nodes=max_leaf_nodes, min_impurity_split=min_impurity_split, warm_start=warm_start, - presort=presort) + presort=presort, categorical=categorical) def _validate_y(self, y): check_classification_targets(y) @@ -1744,6 +1757,17 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin): .. versionadded:: 0.17 optional parameter *presort*. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. + Attributes ---------- feature_importances_ : array, shape = [n_features] @@ -1792,7 +1816,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100, min_samples_leaf=1, min_weight_fraction_leaf=0., max_depth=3, min_impurity_split=1e-7, init=None, random_state=None, max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None, - warm_start=False, presort='auto'): + warm_start=False, presort='auto', categorical='none'): super(GradientBoostingRegressor, self).__init__( loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, @@ -1803,7 +1827,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100, max_features=max_features, min_impurity_split=min_impurity_split, random_state=random_state, alpha=alpha, verbose=verbose, max_leaf_nodes=max_leaf_nodes, warm_start=warm_start, - presort=presort) + presort=presort, categorical=categorical) def predict(self, X): """Predict regression target for X. From 04c5fa95596f5721540797490d460a3ae3eb1bf4 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 29 Nov 2018 14:24:30 +0100 Subject: [PATCH 14/54] compile with recent cython --- sklearn/neighbors/dist_metrics.pyx | 22 ++++++++++------------ sklearn/utils/fast_dict.pxd | 1 + 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/sklearn/neighbors/dist_metrics.pyx b/sklearn/neighbors/dist_metrics.pyx index 6af0441083302..6f23ae3fba75d 100644 --- a/sklearn/neighbors/dist_metrics.pyx +++ b/sklearn/neighbors/dist_metrics.pyx @@ -1097,18 +1097,16 @@ cdef class PyFuncDistance(DistanceMetric): ITYPE_t size) except -1 with gil: cdef np.ndarray x1arr cdef np.ndarray x2arr - with gil: - x1arr = _buffer_to_ndarray(x1, size) - x2arr = _buffer_to_ndarray(x2, size) - d = self.func(x1arr, x2arr, **self.kwargs) - try: - # Cython generates code here that results in a TypeError - # if d is the wrong type. 
- return d - except TypeError: - raise TypeError("Custom distance function must accept two " - "vectors and return a float.") - + x1arr = _buffer_to_ndarray(x1, size) + x2arr = _buffer_to_ndarray(x2, size) + d = self.func(x1arr, x2arr, **self.kwargs) + try: + # Cython generates code here that results in a TypeError + # if d is the wrong type. + return d + except TypeError: + raise TypeError("Custom distance function must accept two " + "vectors and return a float.") cdef inline double fmax(double a, double b) nogil: diff --git a/sklearn/utils/fast_dict.pxd b/sklearn/utils/fast_dict.pxd index 5893c53ac541f..0b63655b2d591 100644 --- a/sklearn/utils/fast_dict.pxd +++ b/sklearn/utils/fast_dict.pxd @@ -8,6 +8,7 @@ integers, and values float. from libcpp.map cimport map as cpp_map # Import the C-level symbols of numpy +import numpy as np cimport numpy as np DTYPE = np.float64 From 95a0bd25037cdc9335d714704b5cdc78c76f8439 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Tue, 25 Dec 2018 23:01:44 +0100 Subject: [PATCH 15/54] compile goes a bit further --- sklearn/neighbors/quad_tree.pyx | 2 +- sklearn/tree/_splitter.pxd | 10 ++++++++-- sklearn/tree/_splitter.pyx | 9 +++++---- sklearn/tree/_tree.pxd | 6 +++++- sklearn/tree/_utils.pxd | 4 ++-- 5 files changed, 21 insertions(+), 10 deletions(-) diff --git a/sklearn/neighbors/quad_tree.pyx b/sklearn/neighbors/quad_tree.pyx index fbe736636c89d..b2a7de0d1ebed 100644 --- a/sklearn/neighbors/quad_tree.pyx +++ b/sklearn/neighbors/quad_tree.pyx @@ -605,7 +605,7 @@ cdef class _QuadTree: else: capacity = 2 * self.capacity - safe_realloc(&self.cells, capacity) + safe_realloc(&self.cells, capacity, sizeof(self.cells[0])) # if capacity smaller than cell_count, adjust the counter if capacity < self.cell_count: diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 035a5dd48ee9d..f4eb05a04c85e 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -1,3 +1,5 @@ +# cython: language_level=3 + # Authors: Gilles Louppe # Peter Prettenhofer # Brian Holt @@ -12,7 +14,7 @@ import numpy as np cimport numpy as np -from ._utils cimport SplitValue +#from ._utils cimport SplitValue from ._criterion cimport Criterion @@ -23,7 +25,11 @@ ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer ctypedef np.npy_uint64 UINT64_t # Unsigned 64 bit integer -cdef struct SplitRecord: +ctypedef union SplitValue: + DOUBLE_t threshold + UINT64_t cat_split + +ctypedef struct SplitRecord: # Data to track sample split SIZE_t feature # Which feature to split on. 
SIZE_t pos # Split samples array at the given position, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 206964e253bcf..21c24ccb01e69 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -1,3 +1,4 @@ +# cython: language_level=3 # cython: cdivision=True # cython: boundscheck=False # cython: wraparound=False @@ -1544,10 +1545,10 @@ cdef class BestSparseSplitter(BaseSparseSplitter): # sum of halves used to avoid infinite values current.split_value.threshold = Xf[p_prev] / 2.0 + Xf[p] / 2.0 - if ((current.threshold == Xf[p]) or - (current.threshold == INFINITY) or - (current.threshold == -INFINITY)): - current.threshold = Xf[p_prev] + if ((current.split_value.threshold == Xf[p]) or + (current.split_value.threshold == INFINITY) or + (current.split_value.threshold == -INFINITY)): + current.split_value.threshold = Xf[p_prev] best = current diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index bc5a273da14d3..9ec23ee88f896 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -19,7 +19,11 @@ ctypedef np.npy_intp SIZE_t # Type for indices and counters ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer -from ._utils cimport SplitValue +#from ._utils cimport SplitValue +ctypedef union SplitValue: + np.npy_float64 threshold + np.npy_uint64 cat_split + from ._splitter cimport Splitter from ._splitter cimport SplitRecord diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 748bfa5f6e080..12f838eb839f6 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -11,8 +11,8 @@ import numpy as np cimport numpy as np -from _tree cimport Node -from sklearn.neighbors.quad_tree cimport Cell +from ._tree cimport Node +from ..neighbors.quad_tree cimport Cell ctypedef np.npy_float32 DTYPE_t # Type of X ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight From 28f42b15594a8c5b71b8b7ad0161eb00e6282c70 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 26 Dec 2018 12:43:08 +0100 Subject: [PATCH 16/54] compiles --- sklearn/ensemble/forest.py | 2 +- sklearn/tree/_splitter.pxd | 18 +----------------- sklearn/tree/_splitter.pyx | 13 +++++++------ sklearn/tree/_tree.pxd | 22 +++------------------- sklearn/tree/_utils.pxd | 29 +++++++++++++++++++++++++++-- 5 files changed, 39 insertions(+), 45 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 0cfcc6b2bf48e..648c4d0a4e417 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -1286,7 +1286,7 @@ def __init__(self, max_leaf_nodes=None, min_impurity_decrease=0., min_impurity_split=None, - categorical='none' + categorical='none', bootstrap=True, oob_score=False, n_jobs=None, diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index f4eb05a04c85e..928fc934dbeb9 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -14,7 +14,7 @@ import numpy as np cimport numpy as np -#from ._utils cimport SplitValue +from ._utils cimport SplitValue, SplitRecord from ._criterion cimport Criterion @@ -25,22 +25,6 @@ ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer ctypedef np.npy_uint64 UINT64_t # Unsigned 64 bit integer -ctypedef union SplitValue: - DOUBLE_t threshold - UINT64_t cat_split - -ctypedef struct SplitRecord: - # Data to track sample split - SIZE_t feature # Which feature to split on. - SIZE_t pos # Split samples array at the given position, - # i.e. 
count of samples below threshold for feature. - # pos is >= end if the node is a leaf. - SplitValue split_value # Generalized threshold for categorical and - # non-categorical features - double improvement # Impurity improvement given parent node. - double impurity_left # Impurity of the left split. - double impurity_right # Impurity of the right split. - cdef class Splitter: # The splitter searches in the input space for a feature and a threshold # to split the samples samples[start:end]. diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 21c24ccb01e69..6a1535f5adb8d 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -421,6 +421,7 @@ cdef class BestSplitter(BaseDenseSplitter): cdef SIZE_t feature_offset cdef SIZE_t i cdef SIZE_t j + cdef UINT64_t ui # unsigned long int i cdef SIZE_t n_visited_features = 0 # Number of features discovered to be constant during the split search @@ -433,8 +434,7 @@ cdef class BestSplitter(BaseDenseSplitter): cdef DTYPE_t current_feature_value cdef SIZE_t partition_end cdef bint is_categorical - cdef UINT64_t cat_idx, cat_split - cdef SIZE_t ncat_present + cdef UINT64_t cat_split, cat_idx, ncat_present cdef INT32_t cat_offs[64] cdef bint breiman_shortcut = self.breiman_shortcut cdef SIZE_t sorted_cat[64] @@ -554,8 +554,8 @@ cdef class BestSplitter(BaseDenseSplitter): break cat_split = 0 - for i in range(cat_idx): - cat_split |= ( 1) << sorted_cat[i] + for ui in range(cat_idx): + cat_split |= ( 1) << sorted_cat[ui] if cat_split & 1: cat_split = (~cat_split) & ( (~( 0)) >> (64 - self.n_categories[current.feature])) @@ -567,8 +567,9 @@ cdef class BestSplitter(BaseDenseSplitter): # We double cat_idx to avoid double-counting equivalent splits # This also ensures that cat_split & 1 == 0 as required cat_split = 0 - for i in range(ncat_present): - cat_split |= ((cat_idx << 1) & (( 1) << i)) << cat_offs[i] + for ui in range(ncat_present): + cat_split |= ((cat_idx << 1) & + (( 1) << ui)) << cat_offs[ui] # Partition j = start diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 9ec23ee88f896..03c1d73b84c0b 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -19,26 +19,10 @@ ctypedef np.npy_intp SIZE_t # Type for indices and counters ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer -#from ._utils cimport SplitValue -ctypedef union SplitValue: - np.npy_float64 threshold - np.npy_uint64 cat_split - - +from ._utils cimport SplitValue +from ._utils cimport SplitRecord +from ._utils cimport Node from ._splitter cimport Splitter -from ._splitter cimport SplitRecord - -cdef struct Node: - # Base storage structure for the nodes in a Tree object - - SIZE_t left_child # id of the left child of the node - SIZE_t right_child # id of the right child of the node - SIZE_t feature # Feature used for splitting the node - SplitValue split_value # Generalized threshold for categorical and - # non-categorical features - DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) - SIZE_t n_node_samples # Number of samples at the node - DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node cdef class CategoryCacheMgr: diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 12f838eb839f6..6ea721a45b6b8 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -11,7 +11,6 @@ import numpy as np cimport numpy as np -from ._tree cimport Node from ..neighbors.quad_tree cimport Cell ctypedef 
np.npy_float32 DTYPE_t # Type of X @@ -41,7 +40,33 @@ ctypedef union SplitValue: DOUBLE_t threshold UINT64_t cat_split -cdef struct Node # Forward declaration + +ctypedef struct SplitRecord: + # Data to track sample split + SIZE_t feature # Which feature to split on. + SIZE_t pos # Split samples array at the given position, + # i.e. count of samples below threshold for feature. + # pos is >= end if the node is a leaf. + SplitValue split_value # Generalized threshold for categorical and + # non-categorical features + double improvement # Impurity improvement given parent node. + double impurity_left # Impurity of the left split. + double impurity_right # Impurity of the right split. + + +cdef struct Node: + # Base storage structure for the nodes in a Tree object + + SIZE_t left_child # id of the left child of the node + SIZE_t right_child # id of the right child of the node + SIZE_t feature # Feature used for splitting the node + SplitValue split_value # Generalized threshold for categorical and + # non-categorical features + DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) + SIZE_t n_node_samples # Number of samples at the node + DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node + +# cdef struct Node # Forward declaration cdef enum: # Max value for our rand_r replacement (near the bottom). From 724e7209e1f818b92db86990fe729ed5b6eee682 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 26 Dec 2018 12:49:02 +0100 Subject: [PATCH 17/54] fix yield based tests --- sklearn/tree/tests/test_tree.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index c1495df4804ea..83ff1ed26baf0 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1813,7 +1813,8 @@ def _pickle_copy(obj): assert_equal(n_samples, n_samples_) -def check_invalid_categorical(name): +@pytest.mark.parametrize('name', ALL_TREES) +def test_invalid_categorical(name): Tree = ALL_TREES[name] raise_on_construction = ['invalid string', [[0]]] raise_on_fit = [[False, False, False], [1, 2], [-3], [0, 0, 1]] @@ -1823,12 +1824,9 @@ def check_invalid_categorical(name): assert_raises(ValueError, Tree(categorical=catval).fit, X, y) -def test_invalid_categorical(): - for name in ALL_TREES: - yield check_invalid_categorical, name - - -def check_no_sparse_with_categorical(name): +@pytest.mark.parametrize('name', SPARSE_TREES) +def test_no_sparse_with_categorical(name): + # Currently we do not support sparse categorical features X, y, X_sparse = [DATASETS['clf_small'][z] for z in ['X', 'y', 'X_sparse']] Tree = ALL_TREES[name] @@ -1838,12 +1836,6 @@ def check_no_sparse_with_categorical(name): Tree(categorical=[6, 10]).fit(X, y).predict, X_sparse) -def test_no_sparse_with_categorical(): - # Currently we do not support sparse categorical features - for name in SPARSE_TREES: - yield check_no_sparse_with_categorical, name - - def test_empty_leaf_infinite_threshold(): # try to make empty leaf by using near infinite value. 
data = np.random.RandomState(0).randn(100, 11) * 2e38 From 614e51c7517cb09389255ee0c58b4df4c04092ca Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 26 Dec 2018 13:30:39 +0100 Subject: [PATCH 18/54] add cat_split to NODE_DTYPE --- sklearn/tree/_tree.pyx | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 582dd1f295397..6da406404a0d4 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -78,16 +78,18 @@ cdef SIZE_t INITIAL_STACK_SIZE = 10 ## 'formats': [np.float64, np.uint64], ## 'offsets': [0, 0] ##}) + NODE_DTYPE = np.dtype({ - 'names': ['left_child', 'right_child', 'feature', 'threshold', 'impurity', - 'n_node_samples', 'weighted_n_node_samples'], - 'formats': [np.intp, np.intp, np.intp, np.float64, np.float64, np.intp, - np.float64], + 'names': ['left_child', 'right_child', 'feature', 'threshold', 'cat_split', + 'impurity', 'n_node_samples', 'weighted_n_node_samples'], + 'formats': [np.intp, np.intp, np.intp, np.float64, np.uint64, np.float64, + np.intp, np.float64], 'offsets': [ &( NULL).left_child, &( NULL).right_child, &( NULL).feature, &( NULL).split_value, + &( NULL).split_value, &( NULL).impurity, &( NULL).n_node_samples, &( NULL).weighted_n_node_samples From 51e51cf63b14509c3d43c980be509bef24d0fc0a Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 26 Dec 2018 14:42:46 +0100 Subject: [PATCH 19/54] compare float arrays with almost_equal --- sklearn/tree/tests/test_tree.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 83ff1ed26baf0..08b24c7f796fd 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -177,8 +177,8 @@ def assert_tree_equal(d, s, message): assert_array_equal(d.feature[internal], s.feature[internal], message + ": inequal features") - assert_array_equal(d.threshold[internal], s.threshold[internal], - message + ": inequal threshold") + assert_array_almost_equal(d.threshold[internal], s.threshold[internal], + err_msg=message + ": inequal threshold") assert_array_equal(d.n_node_samples.sum(), s.n_node_samples.sum(), message + ": inequal sum(n_node_samples)") assert_array_equal(d.n_node_samples, s.n_node_samples, @@ -1851,4 +1851,3 @@ def test_empty_leaf_infinite_threshold(): infinite_threshold = np.where(~np.isfinite(tree.tree_.threshold))[0] assert len(infinite_threshold) == 0 assert len(empty_leaf) == 0 - From 78400049a6ec93043afe45bb97693d5e353efd49 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 26 Dec 2018 14:44:41 +0100 Subject: [PATCH 20/54] remove extra import --- sklearn/tree/tree.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index 301d01ddd579f..668a302c3bd60 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -22,7 +22,6 @@ from abc import ABCMeta from abc import abstractmethod from math import ceil -import warnings import numpy as np from scipy.sparse import issparse From 3d126607ef41652d13490a84641a98baaa3a130b Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 26 Dec 2018 14:51:04 +0100 Subject: [PATCH 21/54] remove commented extra lines --- sklearn/tree/_tree.pyx | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 6da406404a0d4..b6cbb79ce473e 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -69,16 +69,6 @@ cdef SIZE_t _TREE_LEAF = TREE_LEAF cdef SIZE_t _TREE_UNDEFINED = 
TREE_UNDEFINED cdef SIZE_t INITIAL_STACK_SIZE = 10 -# Repeat struct definition for numpy -# NOTE: SPLITVALUE_DTYPE cannot be used with numpy < v1.7 -# Recommend replacing threshold with it when support for v1.6 is dropped - -##SPLITVALUE_DTYPE = np.dtype({ -## 'names': ['threshold', 'cat_split'], -## 'formats': [np.float64, np.uint64], -## 'offsets': [0, 0] -##}) - NODE_DTYPE = np.dtype({ 'names': ['left_child', 'right_child', 'feature', 'threshold', 'cat_split', 'impurity', 'n_node_samples', 'weighted_n_node_samples'], From bb02337512cfe49e0ebb79181f3df7b593eb9b8d Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 26 Dec 2018 15:37:05 +0100 Subject: [PATCH 22/54] fix some docstrings --- sklearn/ensemble/forest.py | 26 ++++++-------------------- sklearn/tree/_tree.pyx | 1 - sklearn/tree/tree.py | 25 +++++++++++++++++++++++++ 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 648c4d0a4e417..efda3e0cd933f 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -966,17 +966,10 @@ class labels (multi-output problem). ... n_informative=2, n_redundant=0, ... random_state=0, shuffle=False) >>> clf = RandomForestClassifier(n_estimators=100, max_depth=2, - ... random_state=0) - >>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE - RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini', - max_depth=2, max_features='auto', max_leaf_nodes=None, - min_impurity_decrease=0.0, min_impurity_split=None, - min_samples_leaf=1, min_samples_split=2, - min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None, - oob_score=False, random_state=0, verbose=0, warm_start=False) - >>> print(clf.feature_importances_) + ... random_state=0).fit(X, y) + >>> clf.feature_importances_ [0.14205973 0.76664038 0.0282433 0.06305659] - >>> print(clf.predict([[0, 0, 0, 0]])) + >>> clf.predict([[0, 0, 0, 0]]) [1] Notes @@ -1231,17 +1224,10 @@ class RandomForestRegressor(ForestRegressor): >>> X, y = make_regression(n_features=4, n_informative=2, ... random_state=0, shuffle=False) >>> regr = RandomForestRegressor(max_depth=2, random_state=0, - ... n_estimators=100) - >>> regr.fit(X, y) # doctest: +NORMALIZE_WHITESPACE - RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=2, - max_features='auto', max_leaf_nodes=None, - min_impurity_decrease=0.0, min_impurity_split=None, - min_samples_leaf=1, min_samples_split=2, - min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None, - oob_score=False, random_state=0, verbose=0, warm_start=False) - >>> print(regr.feature_importances_) + ... n_estimators=100).fit(X, y) + >>> regr.feature_importances_ [0.18146984 0.81473937 0.00145312 0.00233767] - >>> print(regr.predict([[0, 0, 0, 0]])) + >>> regr.predict([[0, 0, 0, 0]]) [-8.32987858] Notes diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index b6cbb79ce473e..183e3d42c0c01 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -1160,7 +1160,6 @@ cdef class Tree: return out - cpdef compute_feature_importances(self, normalize=True): """Computes the importance of each feature (aka variable).""" cdef Node* left diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index 668a302c3bd60..fe8601e5a9c63 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -1401,6 +1401,19 @@ class ExtraTreeClassifier(DecisionTreeClassifier): Note that these weights will be multiplied with sample_weight (passed through the fit method) if sample_weight is specified. 
+ categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. + + .. versionadded:: 0.21 + See also -------- ExtraTreeRegressor, sklearn.ensemble.ExtraTreesClassifier, @@ -1571,6 +1584,18 @@ class ExtraTreeRegressor(DecisionTreeRegressor): Best nodes are defined as relative reduction in impurity. If None then unlimited number of leaf nodes. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. + + .. versionadded:: 0.21 See also -------- From 56cdb725033ee9da48e6e89288f6de834ee57f34 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 26 Dec 2018 17:31:00 +0100 Subject: [PATCH 23/54] remove overlapping DTYPE --- sklearn/tree/_tree.pyx | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 183e3d42c0c01..78647f7e777c8 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -69,6 +69,9 @@ cdef SIZE_t _TREE_LEAF = TREE_LEAF cdef SIZE_t _TREE_UNDEFINED = TREE_UNDEFINED cdef SIZE_t INITIAL_STACK_SIZE = 10 +""" +this includes cat_split, but it breaks joblib.hash + NODE_DTYPE = np.dtype({ 'names': ['left_child', 'right_child', 'feature', 'threshold', 'cat_split', 'impurity', 'n_node_samples', 'weighted_n_node_samples'], @@ -85,6 +88,23 @@ NODE_DTYPE = np.dtype({ &( NULL).weighted_n_node_samples ] }) +""" + +NODE_DTYPE = np.dtype({ + 'names': ['left_child', 'right_child', 'feature', 'threshold', + 'impurity', 'n_node_samples', 'weighted_n_node_samples'], + 'formats': [np.intp, np.intp, np.intp, np.float64, np.float64, + np.intp, np.float64], + 'offsets': [ + &( NULL).left_child, + &( NULL).right_child, + &( NULL).feature, + &( NULL).split_value, + &( NULL).impurity, + &( NULL).n_node_samples, + &( NULL).weighted_n_node_samples + ] +}) # ============================================================================= # TreeBuilder From d49ff0e5261b6e89d36bc8c402af7b3ea35b2d68 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 26 Dec 2018 18:01:40 +0100 Subject: [PATCH 24/54] fix forest doctest --- sklearn/ensemble/forest.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index efda3e0cd933f..3d864a2e4544a 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -873,6 +873,8 @@ class RandomForestClassifier(ForestClassifier): labels using the ``Gini`` or ``Entropy`` criteria. 
In this case, the runtime is linear in the number of categories. + .. versionadded:: 0.21 + bootstrap : boolean, optional (default=True) Whether bootstrap samples are used when building trees. @@ -968,9 +970,9 @@ class labels (multi-output problem). >>> clf = RandomForestClassifier(n_estimators=100, max_depth=2, ... random_state=0).fit(X, y) >>> clf.feature_importances_ - [0.14205973 0.76664038 0.0282433 0.06305659] + array([0.14205973, 0.76664038, 0.0282433 , 0.06305659]) >>> clf.predict([[0, 0, 0, 0]]) - [1] + array([1]) Notes ----- @@ -1169,6 +1171,8 @@ class RandomForestRegressor(ForestRegressor): labels using the ``Gini`` or ``Entropy`` criteria. In this case, the runtime is linear in the number of categories. + .. versionadded:: 0.21 + bootstrap : boolean, optional (default=True) Whether bootstrap samples are used when building trees. @@ -1226,9 +1230,9 @@ class RandomForestRegressor(ForestRegressor): >>> regr = RandomForestRegressor(max_depth=2, random_state=0, ... n_estimators=100).fit(X, y) >>> regr.feature_importances_ - [0.18146984 0.81473937 0.00145312 0.00233767] + array([0.18146984, 0.81473937, 0.00145312, 0.00233767]) >>> regr.predict([[0, 0, 0, 0]]) - [-8.32987858] + array([-8.32987858]) Notes ----- From 081dd830501521dab407377cd4d46779f5206ba2 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 26 Dec 2018 18:15:27 +0100 Subject: [PATCH 25/54] remove input validation from __init__, it's don in fit --- sklearn/tree/tree.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index fe8601e5a9c63..b2171572d58f8 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -110,15 +110,6 @@ def __init__(self, self.presort = presort self.categorical = categorical - # Input validation for parameter categorical - if isinstance(self.categorical, str): - if categorical not in ('all', 'none'): - raise ValueError("Invalid value for categorical: {}. Allowed" - " strings are 'all' or 'none'" - "".format(categorical)) - elif len(np.shape(categorical)) != 1: - raise ValueError("Invalid shape for categorical") - def get_depth(self): """Returns the depth of the decision tree. From 51b908cd674c70a7244bdeb95433bc8e08341228 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 26 Dec 2018 18:29:04 +0100 Subject: [PATCH 26/54] fix extra tree param docstring --- sklearn/ensemble/forest.py | 4 ++++ sklearn/tree/tree.py | 20 ++++++-------------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 3d864a2e4544a..177fe137b03a5 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -1425,6 +1425,8 @@ class ExtraTreesClassifier(ForestClassifier): have an upper limit of :math:`2^{31}` categories, and runtimes linear in the number of categories. + .. versionadded:: 0.21 + bootstrap : boolean, optional (default=False) Whether bootstrap samples are used when building trees. @@ -1695,6 +1697,8 @@ class ExtraTreesRegressor(ForestRegressor): have an upper limit of :math:`2^{31}` categories, and runtimes linear in the number of categories. + .. versionadded:: 0.21 + bootstrap : boolean, optional (default=False) Whether bootstrap samples are used when building trees. 
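For orientation, a minimal usage sketch of the parameter documented above;
the data and values here are illustrative and not part of the patch:

    import numpy as np
    from sklearn.ensemble import ExtraTreesClassifier

    rng = np.random.RandomState(0)
    # Column 0 is ordinal; columns 1 and 2 hold non-negative category codes.
    X = np.hstack([rng.randn(100, 1), rng.randint(0, 3, size=(100, 2))])
    y = (X[:, 1] == X[:, 2]).astype(int)

    # `categorical` accepts feature indices, a boolean mask of length
    # n_features, or the strings 'all' / 'none'.
    clf = ExtraTreesClassifier(n_estimators=10, random_state=0,
                               categorical=[1, 2]).fit(X, y)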
diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index b2171572d58f8..34de67ffbb8ec 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -1395,13 +1395,9 @@ class ExtraTreeClassifier(DecisionTreeClassifier): categorical : array-like or str Array of feature indices, boolean array of length n_features, ``'all'`` or ``'none'``. Indicates which features should be - considered as categorical rather than ordinal. For decision trees, - the maximum number of categories is 64. In practice, the limit will - often be lower because the process of searching for the best possible - split grows exponentially with the number of categories. However, a - shortcut due to Breiman (1984) is used when fitting data with binary - labels using the ``Gini`` or ``Entropy`` criteria. In this case, - the runtime is linear in the number of categories. + considered as categorical rather than ordinal. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. .. versionadded:: 0.21 @@ -1578,13 +1574,9 @@ class ExtraTreeRegressor(DecisionTreeRegressor): categorical : array-like or str Array of feature indices, boolean array of length n_features, ``'all'`` or ``'none'``. Indicates which features should be - considered as categorical rather than ordinal. For decision trees, - the maximum number of categories is 64. In practice, the limit will - often be lower because the process of searching for the best possible - split grows exponentially with the number of categories. However, a - shortcut due to Breiman (1984) is used when fitting data with binary - labels using the ``Gini`` or ``Entropy`` criteria. In this case, - the runtime is linear in the number of categories. + considered as categorical rather than ordinal. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. .. 
versionadded:: 0.21 From 45d1f3386e7e2f4d0f703ba86d6b479b3213dd09 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 26 Dec 2018 23:01:53 +0100 Subject: [PATCH 27/54] improve tests for invalid categorical input --- sklearn/tree/tests/test_tree.py | 15 +++++++-------- sklearn/tree/tree.py | 16 +++++++++------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 08b24c7f796fd..0d5a7994365db 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1813,15 +1813,14 @@ def _pickle_copy(obj): assert_equal(n_samples, n_samples_) -@pytest.mark.parametrize('name', ALL_TREES) -def test_invalid_categorical(name): +@pytest.mark.parametrize('name', ALL_TREES.keys()) +@pytest.mark.parametrize('categorical', ['invalid string', [[0]], + [False, False, False], [1, 2], [-3], + [0, 0, 1]]) +def test_invalid_categorical(name, categorical): Tree = ALL_TREES[name] - raise_on_construction = ['invalid string', [[0]]] - raise_on_fit = [[False, False, False], [1, 2], [-3], [0, 0, 1]] - for catval in raise_on_construction: - assert_raises(ValueError, Tree, categorical=catval) - for catval in raise_on_fit: - assert_raises(ValueError, Tree(categorical=catval).fit, X, y) + with pytest.raises(ValueError, match="Invalid value for categorical"): + Tree(categorical=categorical).fit(X, y) @pytest.mark.parametrize('name', SPARSE_TREES) diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index 34de67ffbb8ec..77f7674dbde60 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -291,7 +291,6 @@ def fit(self, X, y, sample_weight=None, check_input=True, elif self.categorical == 'all': categorical = np.arange(self.n_features_) else: - # Should have been caught in the constructor, but just in case raise ValueError("Invalid value for categorical: {}. 
Allowed" " strings are 'all' or 'none'" "".format(self.categorical)) @@ -299,23 +298,26 @@ def fit(self, X, y, sample_weight=None, check_input=True, categorical = np.atleast_1d(self.categorical).flatten() if categorical.dtype == np.bool: if categorical.size != self.n_features_: - raise ValueError("Shape of boolean parameter categorical must" - " be (n_features,)") + raise ValueError("Invalid value for categorical: Shape of " + "boolean parameter categorical must " + "be (n_features,)") categorical = np.nonzero(categorical)[0] if (np.size(categorical) > self.n_features_ or (categorical.size > 0 and (categorical.min() < 0 or categorical.max() >= self.n_features_))): - raise ValueError("Invalid shape or invalid feature index for" - " parameter categorical") + raise ValueError("Invalid value for categorical: Invalid shape or " + "feature index for parameter categorical " + "invalid.") if issparse(X): if categorical.size > 0: raise NotImplementedError("Categorical features not supported" " with sparse inputs") else: if np.any(X[:, categorical].astype(np.int) < 0): - raise ValueError("Invalid training data: categorical values" - " must be non-negative.") + raise ValueError("Invalid value for categorical: given values " + "for categorical features must be " + "non-negative.") # Calculate n_categories and verify they are all at least 1% populated n_categories = np.array([np.int(X[:, i].max()) + 1 if i in categorical From 4f1f36066417f019d5d75fa8c7bc3a2b08dc8cfd Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 27 Dec 2018 11:24:17 +0100 Subject: [PATCH 28/54] use pytest.raises instead in invalid categorical sparse test --- sklearn/tree/tests/test_tree.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 0d5a7994365db..b3fb1dc64b979 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1813,7 +1813,7 @@ def _pickle_copy(obj): assert_equal(n_samples, n_samples_) -@pytest.mark.parametrize('name', ALL_TREES.keys()) +@pytest.mark.parametrize('name', ALL_TREES) @pytest.mark.parametrize('categorical', ['invalid string', [[0]], [False, False, False], [1, 2], [-3], [0, 0, 1]]) @@ -1823,16 +1823,19 @@ def test_invalid_categorical(name, categorical): Tree(categorical=categorical).fit(X, y) -@pytest.mark.parametrize('name', SPARSE_TREES) +@pytest.mark.parametrize('name', ALL_TREES) def test_no_sparse_with_categorical(name): # Currently we do not support sparse categorical features X, y, X_sparse = [DATASETS['clf_small'][z] for z in ['X', 'y', 'X_sparse']] Tree = ALL_TREES[name] - assert_raises(NotImplementedError, Tree(categorical=[6, 10]).fit, - X_sparse, y) - assert_raises(NotImplementedError, - Tree(categorical=[6, 10]).fit(X, y).predict, X_sparse) + with pytest.raises(NotImplementedError, + match="Categorical features not supported with sparse"): + Tree(categorical=[6, 10]).fit(X_sparse, y) + + with pytest.raises(NotImplementedError, + match="Categorical features not supported with sparse"): + Tree(categorical=[6, 10]).fit(X, y).predict(X_sparse) def test_empty_leaf_infinite_threshold(): From ea263ead55844452e6faa5e4759e7d841be2bc06 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 27 Dec 2018 11:38:44 +0100 Subject: [PATCH 29/54] add cython code coverage to see uncovered code. 
--- .coveragerc | 1 + 1 file changed, 1 insertion(+) diff --git a/.coveragerc b/.coveragerc index 6d76a5bca8235..7be9478b845d1 100644 --- a/.coveragerc +++ b/.coveragerc @@ -2,6 +2,7 @@ branch = True source = sklearn include = */sklearn/* +plugins = Cython.Coverage omit = */sklearn/externals/* */benchmarks/* From e173dee0ef23ad1ba707dbd41637be1cd0166a48 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 27 Dec 2018 11:42:49 +0100 Subject: [PATCH 30/54] add language_level and profiling directives to cython files --- sklearn/tree/_criterion.pxd | 4 ++++ sklearn/tree/_criterion.pyx | 3 +++ sklearn/tree/_splitter.pxd | 2 ++ sklearn/tree/_splitter.pyx | 2 ++ sklearn/tree/_tree.pxd | 4 ++++ sklearn/tree/_tree.pyx | 3 +++ sklearn/tree/_utils.pxd | 4 ++++ sklearn/tree/_utils.pyx | 3 +++ 8 files changed, 25 insertions(+) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 1cbd395af8e37..3495198439660 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -1,3 +1,7 @@ +# cython: linetrace=True +# distutils: define_macros=CYTHON_TRACE_NOGIL=1 +# cython: language_level=3 + # Authors: Gilles Louppe # Peter Prettenhofer # Brian Holt diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index fd0dbd5153a21..2fa467fec4048 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1,3 +1,6 @@ +# cython: linetrace=True +# distutils: define_macros=CYTHON_TRACE_NOGIL=1 +# cython: language_level=3 # cython: cdivision=True # cython: boundscheck=False # cython: wraparound=False diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 928fc934dbeb9..d3389ed77243a 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -1,3 +1,5 @@ +# cython: linetrace=True +# distutils: define_macros=CYTHON_TRACE_NOGIL=1 # cython: language_level=3 # Authors: Gilles Louppe diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 6a1535f5adb8d..1b15780862f4d 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -1,3 +1,5 @@ +# cython: linetrace=True +# distutils: define_macros=CYTHON_TRACE_NOGIL=1 # cython: language_level=3 # cython: cdivision=True # cython: boundscheck=False diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 03c1d73b84c0b..6cfcdeefb2e56 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -1,3 +1,7 @@ +# cython: linetrace=True +# distutils: define_macros=CYTHON_TRACE_NOGIL=1 +# cython: language_level=3 + # Authors: Gilles Louppe # Peter Prettenhofer # Brian Holt diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 78647f7e777c8..f314ce0327d1b 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -1,3 +1,6 @@ +# cython: linetrace=True +# distutils: define_macros=CYTHON_TRACE_NOGIL=1 +# cython: language_level=3 # cython: cdivision=True # cython: boundscheck=False # cython: wraparound=False diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 6ea721a45b6b8..db845e4be4eca 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -1,3 +1,7 @@ +# cython: linetrace=True +# distutils: define_macros=CYTHON_TRACE_NOGIL=1 +# cython: language_level=3 + # Authors: Gilles Louppe # Peter Prettenhofer # Arnaud Joly diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 383b0b4a8d848..5b3f5e04e888e 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -1,3 +1,6 @@ +# cython: linetrace=True +# distutils: define_macros=CYTHON_TRACE_NOGIL=1 +# cython: 
language_level=3 # cython: cdivision=True # cython: boundscheck=False # cython: wraparound=False From a4400dbaa0eb3ed010eccd9865b460f2a256e9e2 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 27 Dec 2018 17:50:18 +0100 Subject: [PATCH 31/54] revert linetrace directive --- sklearn/tree/_criterion.pxd | 2 -- sklearn/tree/_criterion.pyx | 2 -- sklearn/tree/_splitter.pxd | 2 -- sklearn/tree/_splitter.pyx | 2 -- sklearn/tree/_tree.pxd | 2 -- sklearn/tree/_tree.pyx | 2 -- sklearn/tree/_utils.pxd | 2 -- sklearn/tree/_utils.pyx | 2 -- 8 files changed, 16 deletions(-) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 3495198439660..7d2802487f416 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -1,5 +1,3 @@ -# cython: linetrace=True -# distutils: define_macros=CYTHON_TRACE_NOGIL=1 # cython: language_level=3 # Authors: Gilles Louppe diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 2fa467fec4048..cceb358e94f2b 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1,5 +1,3 @@ -# cython: linetrace=True -# distutils: define_macros=CYTHON_TRACE_NOGIL=1 # cython: language_level=3 # cython: cdivision=True # cython: boundscheck=False diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index d3389ed77243a..928fc934dbeb9 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -1,5 +1,3 @@ -# cython: linetrace=True -# distutils: define_macros=CYTHON_TRACE_NOGIL=1 # cython: language_level=3 # Authors: Gilles Louppe diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 1b15780862f4d..6a1535f5adb8d 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -1,5 +1,3 @@ -# cython: linetrace=True -# distutils: define_macros=CYTHON_TRACE_NOGIL=1 # cython: language_level=3 # cython: cdivision=True # cython: boundscheck=False diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 6cfcdeefb2e56..08ac5f33f2d08 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -1,5 +1,3 @@ -# cython: linetrace=True -# distutils: define_macros=CYTHON_TRACE_NOGIL=1 # cython: language_level=3 # Authors: Gilles Louppe diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index f314ce0327d1b..a05079e17e49e 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -1,5 +1,3 @@ -# cython: linetrace=True -# distutils: define_macros=CYTHON_TRACE_NOGIL=1 # cython: language_level=3 # cython: cdivision=True # cython: boundscheck=False diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index db845e4be4eca..e88c27d53f06a 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -1,5 +1,3 @@ -# cython: linetrace=True -# distutils: define_macros=CYTHON_TRACE_NOGIL=1 # cython: language_level=3 # Authors: Gilles Louppe diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 5b3f5e04e888e..b25d516118127 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -1,5 +1,3 @@ -# cython: linetrace=True -# distutils: define_macros=CYTHON_TRACE_NOGIL=1 # cython: language_level=3 # cython: cdivision=True # cython: boundscheck=False From b88cef29bcd4c17b98903ab39285fb404f368e32 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 28 Dec 2018 12:16:06 +0100 Subject: [PATCH 32/54] revert coveragerc cython support --- .coveragerc | 1 - 1 file changed, 1 deletion(-) diff --git a/.coveragerc b/.coveragerc index 7be9478b845d1..6d76a5bca8235 100644 --- a/.coveragerc +++ b/.coveragerc @@ -2,7 
+2,6 @@ branch = True source = sklearn include = */sklearn/* -plugins = Cython.Coverage omit = */sklearn/externals/* */benchmarks/* From c9b263c59f3cfa7da4753a8c2d41801af54291e4 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sat, 12 Jan 2019 15:33:35 +0100 Subject: [PATCH 33/54] benchmark added --- benchmarks/bench_tree_nocats.py | 102 ++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 benchmarks/bench_tree_nocats.py diff --git a/benchmarks/bench_tree_nocats.py b/benchmarks/bench_tree_nocats.py new file mode 100644 index 0000000000000..b26763c6e42d4 --- /dev/null +++ b/benchmarks/bench_tree_nocats.py @@ -0,0 +1,102 @@ +import sys + +from timeit import timeit +from itertools import product +import numpy as np +import pandas as pd + +from sklearn.preprocessing import OneHotEncoder +from sklearn.model_selection import StratifiedKFold +from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier +from sklearn.metrics import roc_auc_score +from sklearn.datasets import fetch_openml + + +def get_data(trunc_ncat): + # the data is located here: https://www.openml.org/d/4135 + data = fetch_openml(data_id=4135) + X = pd.DataFrame(data.data) + y = data.target + + Xdicts = [] + for trunc in trunc_ncat: + X_trunc = X % trunc if trunc > 0 else X + keep_idx = np.array([idx[0] for idx in + X_trunc.groupby(list(X.columns)).groups.values()]) + X_trunc = X_trunc.values[keep_idx] + y_trunc = y[keep_idx] + + X_ohe = OneHotEncoder(categories='auto').fit_transform(X_trunc) + + Xdicts.append({'X': X_trunc, 'y': y_trunc, 'ohe': False, + 'trunc': trunc}) + Xdicts.append({'X': X_ohe, 'y': y_trunc, 'ohe': True, + 'trunc': trunc}) + + return Xdicts + + +# Training dataset +trunc_factor = [4, 6, 8, 10, 12, 14, 16, 0] +data = get_data(trunc_factor) + +for bleh in range(1): + outfile = sys.stdout + + # Loop over classifiers and datasets + for Xydict, clf_type in product( + data, [RandomForestClassifier, ExtraTreesClassifier]): + + # Can't use non-truncated categorical data with RandomForest + if (clf_type is RandomForestClassifier and + not Xydict['ohe'] and not Xydict['trunc']): + continue + + X = Xydict['X'] + y = Xydict['y'] + tech = 'One-hot' if Xydict['ohe'] else 'NOCATS' + trunc = ('truncated({})'.format(Xydict['trunc']) if Xydict['trunc'] > 0 + else 'full') + cat = 'none' if Xydict['ohe'] else 'all' + cv = StratifiedKFold(n_splits=5, shuffle=True, + random_state=17).split(X, y) + + traintimes = [] + testtimes = [] + for train, test in cv: + # Train + clf = clf_type(n_estimators=10, max_features=None, + min_samples_leaf=1, random_state=23, + bootstrap=False, max_depth=None, + categorical=cat) + + traintimes.append(timeit( + "clf.fit(X[train], y[train])".format(cat), + 'from __main__ import clf, X, y, train', number=1)) + + # Check that all leaf nodes are pure + for est in clf.estimators_: + leaves = est.tree_.children_left < 0 + print(np.max(est.tree_.impurity[leaves])) + #assert(np.all(est.tree_.impurity[leaves] == 0)) + + # Test + probs = [] + testtimes.append(timeit( + 'probs.append(clf.predict_proba(X[test]))', + 'from __main__ import probs, clf, X, test', number=1)) + + print('({}, {}, {}) AUC: {}'.format( + clf_type.__name__, trunc, tech, + roc_auc_score(y[test], probs[0][:, 1])), file=outfile) + + traintimes = np.array(traintimes) + testtimes = np.array(testtimes) + print('({}, {}, {}) min/mean/max train times: {} {} {}'.format( + clf_type.__name__, trunc, tech, + traintimes.min(), traintimes.mean(), traintimes.max()), + file=outfile) + print('({}, {}, {}) 
min/mean/max test times: {} {} {}'.format( + clf_type.__name__, trunc, tech, + testtimes.min(), testtimes.mean(), testtimes.max()), file=outfile) + print(file=outfile) From 6ae188d34a62941a714251b4a40ef27644291a1a Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sat, 12 Jan 2019 15:57:36 +0100 Subject: [PATCH 34/54] some benchmark cleanup --- benchmarks/bench_tree_nocats.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/benchmarks/bench_tree_nocats.py b/benchmarks/bench_tree_nocats.py index b26763c6e42d4..f9004d570d63e 100644 --- a/benchmarks/bench_tree_nocats.py +++ b/benchmarks/bench_tree_nocats.py @@ -14,9 +14,8 @@ def get_data(trunc_ncat): # the data is located here: https://www.openml.org/d/4135 - data = fetch_openml(data_id=4135) - X = pd.DataFrame(data.data) - y = data.target + X, y = fetch_openml(data_id=4135, return_X_y=True) + X = pd.DataFrame(X) Xdicts = [] for trunc in trunc_ncat: @@ -37,7 +36,8 @@ def get_data(trunc_ncat): # Training dataset -trunc_factor = [4, 6, 8, 10, 12, 14, 16, 0] +# trunc_factor = [4, 6, 8, 10, 12, 14, 16, 0] +trunc_factor = [4, 16, 0] data = get_data(trunc_factor) for bleh in range(1): @@ -52,8 +52,7 @@ def get_data(trunc_ncat): not Xydict['ohe'] and not Xydict['trunc']): continue - X = Xydict['X'] - y = Xydict['y'] + X, y = Xydict['X'], Xydict['y'] tech = 'One-hot' if Xydict['ohe'] else 'NOCATS' trunc = ('truncated({})'.format(Xydict['trunc']) if Xydict['trunc'] > 0 else 'full') @@ -74,11 +73,13 @@ def get_data(trunc_ncat): "clf.fit(X[train], y[train])".format(cat), 'from __main__ import clf, X, y, train', number=1)) + """ # Check that all leaf nodes are pure for est in clf.estimators_: leaves = est.tree_.children_left < 0 print(np.max(est.tree_.impurity[leaves])) #assert(np.all(est.tree_.impurity[leaves] == 0)) + """ # Test probs = [] From 9bae7d0d38bb045a293db94bf676d024c1b93156 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sat, 12 Jan 2019 17:28:38 +0100 Subject: [PATCH 35/54] more benchmark cleanup --- benchmarks/bench_tree_nocats.py | 128 ++++++++++++++++---------------- 1 file changed, 64 insertions(+), 64 deletions(-) diff --git a/benchmarks/bench_tree_nocats.py b/benchmarks/bench_tree_nocats.py index f9004d570d63e..886d6e81fcf0c 100644 --- a/benchmarks/bench_tree_nocats.py +++ b/benchmarks/bench_tree_nocats.py @@ -36,68 +36,68 @@ def get_data(trunc_ncat): # Training dataset -# trunc_factor = [4, 6, 8, 10, 12, 14, 16, 0] -trunc_factor = [4, 16, 0] +trunc_factor = [4, 6, 8, 10, 12, 14, 16, 64, 0] data = get_data(trunc_factor) - -for bleh in range(1): - outfile = sys.stdout - - # Loop over classifiers and datasets - for Xydict, clf_type in product( - data, [RandomForestClassifier, ExtraTreesClassifier]): - - # Can't use non-truncated categorical data with RandomForest - if (clf_type is RandomForestClassifier and - not Xydict['ohe'] and not Xydict['trunc']): - continue - - X, y = Xydict['X'], Xydict['y'] - tech = 'One-hot' if Xydict['ohe'] else 'NOCATS' - trunc = ('truncated({})'.format(Xydict['trunc']) if Xydict['trunc'] > 0 - else 'full') - cat = 'none' if Xydict['ohe'] else 'all' - cv = StratifiedKFold(n_splits=5, shuffle=True, - random_state=17).split(X, y) - - traintimes = [] - testtimes = [] - for train, test in cv: - # Train - clf = clf_type(n_estimators=10, max_features=None, - min_samples_leaf=1, random_state=23, - bootstrap=False, max_depth=None, - categorical=cat) - - traintimes.append(timeit( - "clf.fit(X[train], y[train])".format(cat), - 'from __main__ import clf, X, y, train', number=1)) - - """ - # 
Check that all leaf nodes are pure - for est in clf.estimators_: - leaves = est.tree_.children_left < 0 - print(np.max(est.tree_.impurity[leaves])) - #assert(np.all(est.tree_.impurity[leaves] == 0)) - """ - - # Test - probs = [] - testtimes.append(timeit( - 'probs.append(clf.predict_proba(X[test]))', - 'from __main__ import probs, clf, X, test', number=1)) - - print('({}, {}, {}) AUC: {}'.format( - clf_type.__name__, trunc, tech, - roc_auc_score(y[test], probs[0][:, 1])), file=outfile) - - traintimes = np.array(traintimes) - testtimes = np.array(testtimes) - print('({}, {}, {}) min/mean/max train times: {} {} {}'.format( - clf_type.__name__, trunc, tech, - traintimes.min(), traintimes.mean(), traintimes.max()), - file=outfile) - print('({}, {}, {}) min/mean/max test times: {} {} {}'.format( - clf_type.__name__, trunc, tech, - testtimes.min(), testtimes.mean(), testtimes.max()), file=outfile) - print(file=outfile) +results = [] +# Loop over classifiers and datasets +for Xydict, clf_type in product( + data, [RandomForestClassifier, ExtraTreesClassifier]): + + # Can't use non-truncated categorical data with RandomForest + if (clf_type is RandomForestClassifier and + not Xydict['ohe'] and not Xydict['trunc']): + continue + + X, y = Xydict['X'], Xydict['y'] + tech = 'One-hot' if Xydict['ohe'] else 'NOCATS' + trunc = ('truncated({})'.format(Xydict['trunc']) if Xydict['trunc'] > 0 + else 'full') + cat = 'none' if Xydict['ohe'] else 'all' + cv = StratifiedKFold(n_splits=5, shuffle=True, + random_state=17).split(X, y) + + traintimes = [] + testtimes = [] + aucs = [] + name = '({}, {}, {})'.format(clf_type.__name__, trunc, tech) + + for train, test in cv: + # Train + clf = clf_type(n_estimators=10, max_features=None, + min_samples_leaf=1, random_state=23, + bootstrap=False, max_depth=None, + categorical=cat) + + traintimes.append(timeit( + "clf.fit(X[train], y[train])".format(cat), + 'from __main__ import clf, X, y, train', number=1)) + + """ + # Check that all leaf nodes are pure + for est in clf.estimators_: + leaves = est.tree_.children_left < 0 + print(np.max(est.tree_.impurity[leaves])) + #assert(np.all(est.tree_.impurity[leaves] == 0)) + """ + + # Test + probs = [] + testtimes.append(timeit( + 'probs.append(clf.predict_proba(X[test]))', + 'from __main__ import probs, clf, X, test', number=1)) + + aucs.append(roc_auc_score(y[test], probs[0][:, 1])) + + traintimes = np.array(traintimes) + testtimes = np.array(testtimes) + aucs = np.array(aucs) + results.append([name, traintimes.mean(), traintimes.std(), + testtimes.mean(), testtimes.std(), + aucs.mean(), aucs.std()]) + +results = pd.DataFrame(results) +results.columns = ['name', 'train time mean', 'train time std', + 'test time mean', 'test time std', + 'auc mean', 'auc std'] +results = results.set_index('name') +print(results) From b1bd2d7d5d3cec1cddad9c790c98c82bf5015d41 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sat, 12 Jan 2019 17:36:44 +0100 Subject: [PATCH 36/54] remove extra import --- benchmarks/bench_tree_nocats.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/benchmarks/bench_tree_nocats.py b/benchmarks/bench_tree_nocats.py index 886d6e81fcf0c..f6e212a703da5 100644 --- a/benchmarks/bench_tree_nocats.py +++ b/benchmarks/bench_tree_nocats.py @@ -1,5 +1,3 @@ -import sys - from timeit import timeit from itertools import product import numpy as np From 05d2985f40767ea425e672c80d77700544056680 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sun, 13 Jan 2019 19:49:30 +0100 Subject: [PATCH 37/54] more benchmark touches --- 
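One note on the timing calls this script keeps: the stray ``.format(cat)``
is a no-op, since the template string contains no braces. Passing a callable
to ``timeit`` avoids both the format call and the string-based setup import;
an equivalent measurement, assuming the same surrounding variables:

    from timeit import timeit

    # number=1 because a single fit is expensive enough to time directly
    train_time = timeit(lambda: clf.fit(X[train], y[train]), number=1)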
benchmarks/bench_tree_nocats.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/benchmarks/bench_tree_nocats.py b/benchmarks/bench_tree_nocats.py index f6e212a703da5..5bcdb456a0077 100644 --- a/benchmarks/bench_tree_nocats.py +++ b/benchmarks/bench_tree_nocats.py @@ -34,7 +34,7 @@ def get_data(trunc_ncat): # Training dataset -trunc_factor = [4, 6, 8, 10, 12, 14, 16, 64, 0] +trunc_factor = [2, 3, 4, 5, 6, 8, 10, 12, 14, 16, 64, 0] data = get_data(trunc_factor) results = [] # Loop over classifiers and datasets @@ -42,8 +42,10 @@ def get_data(trunc_ncat): data, [RandomForestClassifier, ExtraTreesClassifier]): # Can't use non-truncated categorical data with RandomForest + # and it becomes intractable with too many categories if (clf_type is RandomForestClassifier and - not Xydict['ohe'] and not Xydict['trunc']): + not Xydict['ohe'] and + (not Xydict['trunc'] or Xydict['trunc'] > 16)): continue X, y = Xydict['X'], Xydict['y'] @@ -93,9 +95,9 @@ def get_data(trunc_ncat): testtimes.mean(), testtimes.std(), aucs.mean(), aucs.std()]) -results = pd.DataFrame(results) -results.columns = ['name', 'train time mean', 'train time std', - 'test time mean', 'test time std', - 'auc mean', 'auc std'] -results = results.set_index('name') -print(results) + results_df = pd.DataFrame(results) + results_df.columns = ['name', 'train time mean', 'train time std', + 'test time mean', 'test time std', + 'auc mean', 'auc std'] + results_df = results_df.set_index('name') + print(results_df) From 428206cd90a908f36a6804273a3fb5d99ec6b93a Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sun, 13 Jan 2019 20:13:40 +0100 Subject: [PATCH 38/54] pep8 --- benchmarks/bench_tree_nocats.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/bench_tree_nocats.py b/benchmarks/bench_tree_nocats.py index 5bcdb456a0077..76a2696e999a6 100644 --- a/benchmarks/bench_tree_nocats.py +++ b/benchmarks/bench_tree_nocats.py @@ -97,7 +97,7 @@ def get_data(trunc_ncat): results_df = pd.DataFrame(results) results_df.columns = ['name', 'train time mean', 'train time std', - 'test time mean', 'test time std', - 'auc mean', 'auc std'] + 'test time mean', 'test time std', + 'auc mean', 'auc std'] results_df = results_df.set_index('name') print(results_df) From 2974d5aa047924300300a05436b48afd826e6ab4 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sun, 13 Jan 2019 20:25:32 +0100 Subject: [PATCH 39/54] pep8 --- benchmarks/bench_tree_nocats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/bench_tree_nocats.py b/benchmarks/bench_tree_nocats.py index 76a2696e999a6..803e85f87318c 100644 --- a/benchmarks/bench_tree_nocats.py +++ b/benchmarks/bench_tree_nocats.py @@ -44,7 +44,7 @@ def get_data(trunc_ncat): # Can't use non-truncated categorical data with RandomForest # and it becomes intractable with too many categories if (clf_type is RandomForestClassifier and - not Xydict['ohe'] and + not Xydict['ohe'] and (not Xydict['trunc'] or Xydict['trunc'] > 16)): continue From 286b04bbf7d8eae8ebc2137fdd66c5046457c89a Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 17 Jan 2019 17:05:46 +0100 Subject: [PATCH 40/54] add some tests --- sklearn/tree/tests/test_tree.py | 91 +++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index b3fb1dc64b979..19bc740c0d1e2 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1853,3 +1853,94 @@ def 
test_empty_leaf_infinite_threshold():
     infinite_threshold = np.where(~np.isfinite(tree.tree_.threshold))[0]
     assert len(infinite_threshold) == 0
     assert len(empty_leaf) == 0
+
+
+def _make_categorical(n_rows: int, n_numerical: int, n_categorical: int,
+                      cat_size: int, n_num_meaningful: int,
+                      n_cat_meaningful: int, regression: bool,
+                      return_tuple: bool, random_state: int):
+
+    from sklearn.preprocessing import OneHotEncoder
+    np.random.seed(random_state)
+    numeric = np.random.standard_normal((n_rows, n_numerical))
+    categorical = np.random.randint(0, cat_size, (n_rows, n_categorical))
+    categorical_ohe = OneHotEncoder(categories='auto').fit_transform(
+        categorical[:, :n_cat_meaningful])
+
+    data_meaningful = np.hstack((numeric[:, :n_num_meaningful],
+                                 categorical_ohe.todense()))
+    _, cols = data_meaningful.shape
+    coefs = np.random.standard_normal(cols)
+    y = np.dot(data_meaningful, coefs)
+    y = np.asarray(y).reshape(-1)
+    X = np.hstack((numeric, categorical))
+
+    if not regression:
+        y = (y < y.mean()).astype(int)
+
+    meaningful_features = np.r_[np.arange(n_num_meaningful),
+                                np.arange(n_cat_meaningful) +
+                                n_numerical]
+
+    if return_tuple:
+        return X, y, meaningful_features
+    else:
+        return {'X': X,
+                'y': y,
+                'meaningful_features': meaningful_features}
+
+
+@pytest.mark.parametrize('model', ALL_TREES)
+@pytest.mark.parametrize('data_params', [
+    {'n_rows': 1000,
+     'n_numerical': 5,
+     'n_categorical': 5,
+     'cat_size': 3,
+     'n_num_meaningful': 2,
+     'n_cat_meaningful': 3},
+    {'n_rows': 1000,
+     'n_numerical': 0,
+     'n_categorical': 5,
+     'cat_size': 3,
+     'n_num_meaningful': 0,
+     'n_cat_meaningful': 3},
+    {'n_rows': 1000,
+     'n_numerical': 5,
+     'n_categorical': 5,
+     'cat_size': 64,
+     'n_num_meaningful': 1,
+     'n_cat_meaningful': 2},
+    {'n_rows': 1000,
+     'n_numerical': 5,
+     'n_categorical': 5,
+     'cat_size': 3,
+     'n_num_meaningful': 0,
+     'n_cat_meaningful': 3}])
+def test_categorical_data(model, data_params):
+    # DecisionTrees are too slow for large category sizes, so skip those
+    # combinations.
+    if data_params['cat_size'] > 8 and 'DecisionTree' in model:
+        return
+
+    X, y, meaningful_features = _make_categorical(
+        **data_params,
+        regression=model in REG_TREES,
+        return_tuple=True,
+        random_state=42)
+    rows, cols = X.shape
+    categorical_features = (np.arange(data_params['n_categorical']) +
+                            data_params['n_numerical'])
+
+    model = ALL_TREES[model](random_state=42,
+                             categorical=categorical_features).fit(X, y)
+    fi = model.feature_importances_
+    bad_features = np.array([True]*cols)
+    bad_features[meaningful_features] = False
+
+    good_ones = fi[meaningful_features]
+    bad_ones = fi[bad_features]
+
+    # all good features should be more important than all bad features.
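+    # Worked example with made-up numbers: if cols == 5,
+    # meaningful_features == [0, 3], and
+    # fi == [0.45, 0.02, 0.01, 0.50, 0.02], then good_ones == [0.45, 0.50]
+    # and bad_ones == [0.02, 0.01, 0.02]; the assertion holds because
+    # min(good_ones) > max(bad_ones).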
+    assert np.all([np.all(x > bad_ones) for x in good_ones])
+
+    leaves = model.tree_.children_left < 0
+    assert np.all(model.tree_.impurity[leaves] < 1e-6)

From 90d7365bc14943b35e2ab918dfc2fa469b72da0c Mon Sep 17 00:00:00 2001
From: adrinjalali
Date: Sat, 19 Jan 2019 11:18:37 +0100
Subject: [PATCH 41/54] tests too hard

---
 sklearn/tree/tests/test_tree.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py
index 19bc740c0d1e2..335d007231aef 100644
--- a/sklearn/tree/tests/test_tree.py
+++ b/sklearn/tree/tests/test_tree.py
@@ -1908,7 +1908,7 @@ def _make_categorical(n_rows: int, n_numerical: int, n_categorical: int,
     'n_numerical': 5,
     'n_categorical': 5,
     'cat_size': 64,
-    'n_num_meaningful': 1,
+    'n_num_meaningful': 0,
     'n_cat_meaningful': 2},

From 475ea7bda2e53761fc7db790c524d676440eb6a6 Mon Sep 17 00:00:00 2001
From: adrinjalali
Date: Mon, 21 Jan 2019 13:53:23 +0100
Subject: [PATCH 42/54] add some forest tests

---
 sklearn/ensemble/tests/test_forest.py | 57 +++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py
index e4eb282923a29..3c6be3635ae48 100644
--- a/sklearn/ensemble/tests/test_forest.py
+++ b/sklearn/ensemble/tests/test_forest.py
@@ -53,6 +53,7 @@ from sklearn.utils.fixes import comb
 from sklearn.tree.tree import SPARSE_SPLITTERS
+from sklearn.tree.tests.test_tree import _make_categorical

 # toy sample
@@ -1337,3 +1338,59 @@ def test_backend_respected():
     clf.predict_proba(X)

     assert ba.count == 0
+
+
+@pytest.mark.parametrize('model', FOREST_CLASSIFIERS_REGRESSORS)
+@pytest.mark.parametrize('data_params', [
+    {'n_rows': 10000,
+     'n_numerical': 10,
+     'n_categorical': 5,
+     'cat_size': 3,
+     'n_num_meaningful': 1,
+     'n_cat_meaningful': 2},
+    {'n_rows': 1000,
+     'n_numerical': 0,
+     'n_categorical': 5,
+     'cat_size': 3,
+     'n_num_meaningful': 0,
+     'n_cat_meaningful': 3},
+    {'n_rows': 1000,
+     'n_numerical': 5,
+     'n_categorical': 5,
+     'cat_size': 64,
+     'n_num_meaningful': 0,
+     'n_cat_meaningful': 2},
+    {'n_rows': 1000,
+     'n_numerical': 5,
+     'n_categorical': 5,
+     'cat_size': 3,
+     'n_num_meaningful': 0,
+     'n_cat_meaningful': 3}])
+def test_categorical_data(model, data_params):
+    # Random forests are too slow for large category sizes, so skip those
+    # combinations.
+    if data_params['cat_size'] > 8 and 'RandomForest' in model:
+        return
+
+    X, y, meaningful_features = _make_categorical(
+        **data_params,
+        regression=model in FOREST_REGRESSORS,
+        return_tuple=True,
+        random_state=42)
+    rows, cols = X.shape
+    categorical_features = (np.arange(data_params['n_categorical']) +
+                            data_params['n_numerical'])
+
+    model = FOREST_CLASSIFIERS_REGRESSORS[model](
+        random_state=42, categorical=categorical_features,
+        n_estimators=100).fit(X, y)
+    fi = model.feature_importances_
+    bad_features = np.array([True]*cols)
+    bad_features[meaningful_features] = False
+
+    good_ones = fi[meaningful_features]
+    bad_ones = fi[bad_features]
+
+    # all good features should be more important than all bad features.
+ assert np.all([np.all(x > bad_ones) for x in good_ones]) From 1e2bcfefe68ded1d44705109d0c5c160bc974c63 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 21 Jan 2019 13:53:57 +0100 Subject: [PATCH 43/54] mostly cosmetics --- sklearn/tree/_splitter.pyx | 22 +++++++++++----------- sklearn/tree/_tree.pyx | 11 ++++++++--- sklearn/tree/_utils.pyx | 9 +++++++++ 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 6a1535f5adb8d..7173f42ec17c5 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -340,8 +340,9 @@ cdef class BestSplitter(BaseDenseSplitter): self.presort), self.__getstate__()) - cdef void _breiman_sort_categories(self, SIZE_t start, SIZE_t end, INT32_t ncat, - SIZE_t ncat_present, const INT32_t *cat_offs, + cdef void _breiman_sort_categories(self, SIZE_t start, SIZE_t end, + INT32_t ncat, SIZE_t ncat_present, + const INT32_t *cat_offset, SIZE_t *sorted_cat) nogil: """The Breiman shortcut for finding the best split involves a preprocessing step wherein we sort the categories by @@ -360,23 +361,22 @@ cdef class BestSplitter(BaseDenseSplitter): SIZE_t cat, localcat SIZE_t q, partition_end DTYPE_t sort_value[64] - DTYPE_t sort_den[64] + DTYPE_t sort_density[64] - for cat in range(ncat): - sort_value[cat] = 0 - sort_den[cat] = 0 + memset(sort_value, 0, 64 * sizeof(DTYPE_t)) + memset(sort_density, 0, 64 * sizeof(DTYPE_t)) for q in range(start, end): cat = Xf[q] w = sample_weight[samples[q]] if sample_weight else 1.0 sort_value[cat] += w * (y[y_stride * samples[q]]) - sort_den[cat] += w + sort_density[cat] += w for localcat in range(ncat_present): - cat = localcat + cat_offs[localcat] - if sort_den[cat] == 0: # Avoid dividing zero by zero - sort_den[cat] = 1 - sort_value[localcat] = sort_value[cat] / sort_den[cat] + cat = localcat + cat_offset[localcat] + if sort_density[cat] == 0: # Avoid dividing zero by zero + sort_density[cat] = 1 + sort_value[localcat] = sort_value[cat] / sort_density[cat] sorted_cat[localcat] = cat sort(&sort_value[0], sorted_cat, ncat_present) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index a05079e17e49e..c71288c8a1a01 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -552,7 +552,8 @@ cdef class CategoryCacheMgr: free(self.bits[i]) free(self.bits) - cdef void populate(self, Node *nodes, SIZE_t n_nodes, INT32_t *n_categories): + cdef void populate(self, Node *nodes, SIZE_t n_nodes, + INT32_t *n_categories): cdef SIZE_t i cdef INT32_t ncat @@ -566,8 +567,12 @@ cdef class CategoryCacheMgr: if nodes[i].left_child != _TREE_LEAF: ncat = n_categories[nodes[i].feature] if ncat > 0: - safe_realloc(&self.bits[i], (ncat + 31) // 32, sizeof(UINT32_t)) - setup_cat_cache(self.bits[i], nodes[i].split_value.cat_split, ncat) + safe_realloc(&self.bits[i], + (ncat + 31) // 32, + sizeof(UINT32_t)) + setup_cat_cache(self.bits[i], + nodes[i].split_value.cat_split, + ncat) # ============================================================================= diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index b25d516118127..a3dcae3d4390a 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -110,6 +110,15 @@ cdef inline void setup_cat_cache(UINT32_t *cachebits, UINT64_t cat_split, for j in range(n_categories): val = rand_int(0, 2, &rng_seed) cachebits[j // 32] |= val << (j % 32) + """ + rng_seed = cat_split >> 32 + 10 + 1111111111 bits + 2 bytes + for j in range((n_categories + 31) // 32): + cachebits[j] = rand_int(0, 1 << 32, 
&rng_seed) + cachebits[j] &= (1 << (n_categories - (32 * j))) - 1 + """ else: # BestSplitter for j in range((n_categories + 31) // 32): From 35f273df6063ab8f76dd531c7e310e7681033e05 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 30 Jan 2019 10:26:31 +0100 Subject: [PATCH 44/54] fix typo in _splitters.pyx --- sklearn/tree/_splitter.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 7173f42ec17c5..e68b7be017f60 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -374,7 +374,7 @@ cdef class BestSplitter(BaseDenseSplitter): for localcat in range(ncat_present): cat = localcat + cat_offset[localcat] - if sort_density[cat] == 0: # Avoid dividing zero by zero + if sort_density[cat] == 0: # Avoid dividing by zero sort_density[cat] = 1 sort_value[localcat] = sort_value[cat] / sort_density[cat] sorted_cat[localcat] = cat From b0d73e021fa61dd0828911945f6fd73a5464d945 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sun, 3 Feb 2019 13:51:31 +0100 Subject: [PATCH 45/54] n_categories as memview --- sklearn/tree/_splitter.pxd | 4 ++-- sklearn/tree/_splitter.pyx | 15 ++++++--------- sklearn/tree/_tree.pyx | 6 ++---- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 928fc934dbeb9..2772f3ba961cd 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -59,7 +59,7 @@ cdef class Splitter: cdef DOUBLE_t* y cdef SIZE_t y_stride cdef DOUBLE_t* sample_weight - cdef INT32_t *n_categories # (n_features,) array giving number of + cdef INT32_t[:] n_categories # (n_features,) array giving number of # categories (<0 for non-categorical) cdef UINT32_t* cat_cache # Cache buffer for fast categorical split evaluation @@ -82,7 +82,7 @@ cdef class Splitter: # Methods cdef int init(self, object X, np.ndarray y, DOUBLE_t* sample_weight, - INT32_t* n_categories, + INT32_t[:] n_categories, np.ndarray X_idx_sorted=*) except -1 cdef int node_reset(self, SIZE_t start, SIZE_t end, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index e68b7be017f60..9c89283a175de 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -98,7 +98,6 @@ cdef class Splitter: self.y = NULL self.y_stride = 0 self.sample_weight = NULL - self.n_categories = NULL self.cat_cache = NULL self.max_features = max_features @@ -115,7 +114,6 @@ cdef class Splitter: free(self.features) free(self.constant_features) free(self.feature_values) - free(self.n_categories) free(self.cat_cache) def __getstate__(self): @@ -128,7 +126,7 @@ cdef class Splitter: object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, - INT32_t* n_categories, + INT32_t[:] n_categories, np.ndarray X_idx_sorted=None) except -1: """Initialize the splitter. 
@@ -199,14 +197,13 @@ cdef class Splitter: # Initialize the number of categories for each feature # A value of -1 indicates a non-categorical feature - safe_realloc(&self.n_categories, n_features, sizeof(INT32_t)) + self.n_categories = np.zeros((n_features,), dtype=np.int32) for i in range(n_features): - self.n_categories[i] = (-1 if n_categories == NULL + self.n_categories[i] = (-1 if n_categories is None else n_categories[i]) # If needed, allocate cache space for categorical splits - cdef INT32_t max_n_categories = max( - [self.n_categories[i] for i in range(n_features)]) + cdef INT32_t max_n_categories = max(self.n_categories) if max_n_categories > 0: safe_realloc(&self.cat_cache, (max_n_categories + 31) // 32, sizeof(UINT32_t)) @@ -298,7 +295,7 @@ cdef class BaseDenseSplitter(Splitter): object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, - INT32_t* n_categories, + INT32_t[:] n_categories, np.ndarray X_idx_sorted=None) except -1: """Initialize the splitter @@ -1060,7 +1057,7 @@ cdef class BaseSparseSplitter(Splitter): object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, - INT32_t* n_categories, + INT32_t[:] n_categories, np.ndarray X_idx_sorted=None) except -1: """Initialize the splitter diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index c71288c8a1a01..5bdda08eeceb5 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -183,7 +183,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef INT32_t *n_categories_ptr = NULL if n_categories is not None: n_categories = np.asarray(n_categories, dtype=np.int32, order='C') - n_categories_ptr = n_categories.data # Initial capacity cdef int init_capacity @@ -205,7 +204,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef double min_impurity_split = self.min_impurity_split # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight_ptr, n_categories_ptr, X_idx_sorted) + splitter.init(X, y, sample_weight_ptr, n_categories, X_idx_sorted) cdef SIZE_t start cdef SIZE_t end @@ -359,7 +358,6 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef INT32_t *n_categories_ptr = NULL if n_categories is not None: n_categories = np.asarray(n_categories, dtype=np.int32, order='C') - n_categories_ptr = n_categories.data # Parameters cdef Splitter splitter = self.splitter @@ -369,7 +367,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef SIZE_t min_samples_split = self.min_samples_split # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight_ptr, n_categories_ptr, X_idx_sorted) + splitter.init(X, y, sample_weight_ptr, n_categories, X_idx_sorted) cdef PriorityHeap frontier = PriorityHeap(INITIAL_STACK_SIZE) cdef PriorityHeapRecord record From bb4abfe450f1c036c11489537f329936b8f53960 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sun, 3 Feb 2019 14:30:17 +0100 Subject: [PATCH 46/54] minor cleanup --- sklearn/tree/_splitter.pyx | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 9c89283a175de..f92c00f455ce6 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -180,7 +180,8 @@ cdef class Splitter: self.weighted_n_samples = weighted_n_samples cdef SIZE_t n_features = X.shape[1] - cdef SIZE_t* features = safe_realloc(&self.features, n_features, sizeof(SIZE_t)) + cdef SIZE_t* features = safe_realloc(&self.features, n_features, + sizeof(SIZE_t)) for i in range(n_features): features[i] = i @@ -197,15 +198,17 @@ cdef 
class Splitter: # Initialize the number of categories for each feature # A value of -1 indicates a non-categorical feature - self.n_categories = np.zeros((n_features,), dtype=np.int32) - for i in range(n_features): - self.n_categories[i] = (-1 if n_categories is None - else n_categories[i]) + if n_categories is None: + self.n_categories = np.array([-1] * n_features, dtype=np.int32) + else: + self.n_categories = np.empty_like(n_categories, dtype=np.int32) + self.n_categories[:] = n_categories # If needed, allocate cache space for categorical splits cdef INT32_t max_n_categories = max(self.n_categories) if max_n_categories > 0: - safe_realloc(&self.cat_cache, (max_n_categories + 31) // 32, sizeof(UINT32_t)) + safe_realloc(&self.cat_cache, (max_n_categories + 31) // 32, + sizeof(UINT32_t)) return 0 @@ -360,6 +363,8 @@ cdef class BestSplitter(BaseDenseSplitter): DTYPE_t sort_value[64] DTYPE_t sort_density[64] + # categorical features with more than 64 categories are not supported + # here. memset(sort_value, 0, 64 * sizeof(DTYPE_t)) memset(sort_density, 0, 64 * sizeof(DTYPE_t)) From 532061ff53db321cabcaa213e520a055ecec36df Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sun, 3 Feb 2019 16:51:53 +0100 Subject: [PATCH 47/54] cat_split as a BitSet --- sklearn/tree/_splitter.pxd | 3 ++- sklearn/tree/_splitter.pyx | 46 +++++++++++++++++++++----------------- sklearn/tree/_utils.pxd | 14 ++++++++++++ sklearn/tree/_utils.pyx | 33 +++++++++++++++++++++++++++ 4 files changed, 75 insertions(+), 21 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 2772f3ba961cd..9ad0670a04a72 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -14,7 +14,7 @@ import numpy as np cimport numpy as np -from ._utils cimport SplitValue, SplitRecord +from ._utils cimport SplitValue, SplitRecord, BitSet from ._criterion cimport Criterion @@ -62,6 +62,7 @@ cdef class Splitter: cdef INT32_t[:] n_categories # (n_features,) array giving number of # categories (<0 for non-categorical) cdef UINT32_t* cat_cache # Cache buffer for fast categorical split evaluation + cdef BitSet cat_split # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. 
With this setting, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index f92c00f455ce6..0df28fdf94ba1 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -36,6 +36,8 @@ from ._utils cimport RAND_R_MAX from ._utils cimport safe_realloc from ._utils cimport setup_cat_cache from ._utils cimport goes_left +from ._utils cimport BitSet + cdef double INFINITY = np.inf @@ -288,6 +290,7 @@ cdef class BaseDenseSplitter(Splitter): self.X_idx_sorted_stride = 0 self.sample_mask = NULL self.presort = presort + self.cat_split = BitSet() def __dealloc__(self): """Destructor.""" @@ -436,7 +439,7 @@ cdef class BestSplitter(BaseDenseSplitter): cdef DTYPE_t current_feature_value cdef SIZE_t partition_end cdef bint is_categorical - cdef UINT64_t cat_split, cat_idx, ncat_present + cdef UINT64_t cat_idx, ncat_present cdef INT32_t cat_offs[64] cdef bint breiman_shortcut = self.breiman_shortcut cdef SIZE_t sorted_cat[64] @@ -527,13 +530,13 @@ cdef class BestSplitter(BaseDenseSplitter): # Identify the number of categories present in this node is_categorical = self.n_categories[current.feature] > 0 if is_categorical: - cat_split = 0 + self.cat_split.reset_all() ncat_present = 0 for i in range(start, end): # Xf[i] < 64 already verified in tree.py - cat_split |= ( 1) << ( Xf[i]) + self.cat_split.set(Xf[i]) for i in range(self.n_categories[current.feature]): - if (cat_split >> i) & 1: + if self.cat_split.get(i): cat_offs[ncat_present] = i - ncat_present ncat_present += 1 if ncat_present <= 3: @@ -555,38 +558,41 @@ cdef class BestSplitter(BaseDenseSplitter): if cat_idx >= ncat_present: break - cat_split = 0 + self.cat_split.reset_all() for ui in range(cat_idx): - cat_split |= ( 1) << sorted_cat[ui] - if cat_split & 1: - cat_split = (~cat_split) & ( - (~( 0)) >> (64 - self.n_categories[current.feature])) + self.cat_split.set(sorted_cat[ui]) + # check if the first bit is 1, if yes, flip all + if self.cat_split.get(0): + self.cat_split.flip_all( + self.n_categories[current.feature]) else: if cat_idx >= ( 1) << (ncat_present - 1): break - # Expand the bits of (2 * cat_idx) out into cat_split - # We double cat_idx to avoid double-counting equivalent splits - # This also ensures that cat_split & 1 == 0 as required - cat_split = 0 - for ui in range(ncat_present): - cat_split |= ((cat_idx << 1) & - (( 1) << ui)) << cat_offs[ui] + # Expand the bits of (2 * cat_idx) out into + # cat_split. We double cat_idx to avoid + # double-counting equivalent splits. 
This also + # ensures that cat_split & 1 == 0 as required + self.cat_split.from_template(cat_idx << 1, + cat_offs, + ncat_present) # Partition j = start partition_end = end while j < partition_end: - if (cat_split >> ( Xf[j])) & 1: + if self.cat_split.get(Xf[j]): j += 1 else: partition_end -= 1 - Xf[j], Xf[partition_end] = Xf[partition_end], Xf[j] + Xf[j], Xf[partition_end] = ( + Xf[partition_end], Xf[j]) samples[j], samples[partition_end] = ( samples[partition_end], samples[j]) current.pos = j - # Must reset criterion since we've reordered the samples + # Must reset criterion since we've reordered the + # samples self.criterion.reset() else: # Non-categorical feature @@ -622,7 +628,7 @@ cdef class BestSplitter(BaseDenseSplitter): if current_proxy_improvement > best_proxy_improvement: best_proxy_improvement = current_proxy_improvement if is_categorical: - current.split_value.cat_split = cat_split + current.split_value.cat_split = self.cat_split.value else: current.split_value.threshold = (Xf[p - 1] + Xf[p]) / 2.0 if (current.split_value.threshold == Xf[p] diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index e88c27d53f06a..f059759cafa6b 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -230,3 +230,17 @@ cdef class WeightedMedianCalculator: self, DOUBLE_t data, DOUBLE_t weight, DOUBLE_t original_median) nogil cdef DOUBLE_t get_median(self) nogil + + +cdef class BitSet: + cdef UINT64_t value + + cdef inline void reset_all(self) nogil + cdef inline void set(self, SIZE_t i) nogil + cdef inline void reset(self, SIZE_t i) nogil + cdef inline void flip(self, SIZE_t i) nogil + cdef inline void flip_all(self, SIZE_t n_low_bits) nogil + cdef inline int get(self, SIZE_t i) nogil + cdef inline void from_template(self, UINT64_t template, + INT32_t *cat_offs, + SIZE_t ncats_present) nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index a3dcae3d4390a..1e6293fa58678 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -723,3 +723,36 @@ cdef class WeightedMedianCalculator: if self.sum_w_0_k > (self.total_weight / 2.0): # whole median return self.samples.get_value_from_index(self.k-1) + + +cdef class BitSet: + """Easy bit operations on a UINT64_t value""" + def __cinit__(self): + self.value = 0 + + cdef inline void reset_all(self) nogil: + self.value = 0 + + cdef inline void set(self, SIZE_t i) nogil: + self.value |= ( 1) << i + + cdef inline void reset(self, SIZE_t i) nogil: + self.value &= ~(( 1) << i) + + cdef inline void flip(self, SIZE_t i) nogil: + self.value ^= ( 1) << i + + cdef inline void flip_all(self, SIZE_t n_low_bits) nogil: + self.value = (~self.value) & ((~( 0)) >> (64 - n_low_bits)) + + cdef inline int get(self, SIZE_t i) nogil: + return (self.value >> i) & ( 1) + + cdef inline void from_template(self, UINT64_t template, + INT32_t *cat_offs, + SIZE_t ncats_present) nogil: + cdef SIZE_t i + self.value = 0 + for i in range(ncats_present): + self.value |= (template & + (( 1) << i)) << cat_offs[i] From 15a184efb8d113026168a5a959a0b5878666162e Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 4 Feb 2019 10:39:50 +0100 Subject: [PATCH 48/54] minor fix (n_categories < 0) --- sklearn/tree/_utils.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 1e6293fa58678..983a8a1aafec6 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -130,7 +130,7 @@ cdef inline bint goes_left(DTYPE_t feature_value, SplitValue split, """Determine whether a 
sample goes to the left or right child node.""" cdef SIZE_t idx, shift - if n_categories < 1: + if n_categories < 0: # Non-categorical feature return feature_value <= split.threshold else: From 38e7b95aa29569ce77603787b34fb5c4806a58a7 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 4 Feb 2019 21:27:26 +0100 Subject: [PATCH 49/54] move from class BitSet to BITSET_t and functions --- sklearn/ensemble/_gradient_boosting.pyx | 11 +-- sklearn/tree/_splitter.pxd | 6 +- sklearn/tree/_splitter.pyx | 34 +++---- sklearn/tree/_tree.pxd | 5 +- sklearn/tree/_tree.pyx | 21 ++--- sklearn/tree/_utils.pxd | 27 +++--- sklearn/tree/_utils.pyx | 113 ++++++++++++++---------- 7 files changed, 122 insertions(+), 95 deletions(-) diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx index 59ded374996ff..c1c3f4ffeb1cf 100644 --- a/sklearn/ensemble/_gradient_boosting.pyx +++ b/sklearn/ensemble/_gradient_boosting.pyx @@ -25,6 +25,7 @@ from sklearn.tree._tree cimport DTYPE_t from sklearn.tree._tree cimport SIZE_t from sklearn.tree._tree cimport INT32_t from sklearn.tree._tree cimport UINT32_t +from sklearn.tree._tree cimport BITSET_t from sklearn.tree._utils cimport safe_realloc from sklearn.tree._utils cimport goes_left @@ -52,7 +53,7 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X, Py_ssize_t n_samples, Py_ssize_t n_features, INT32_t* n_categories, - UINT32_t** cachebits, + BITSET_t** cachebits, float64 *out): """Predicts output for regression tree and stores it in ``out[i, k]``. @@ -90,7 +91,7 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X, n_categories : INT32_t pointer Array of length n_features containing the number of categories (for categorical features) or -1 (for non-categorical features) - cachebits : UINT32_t pointer pointer + cachebits : BITSET_t pointer pointer Array of length node_count containing category cache buffers for categorical features out : np.float64_t pointer @@ -100,7 +101,7 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X, """ cdef Py_ssize_t i cdef Node *node - cdef UINT32_t* node_cache + cdef BITSET_t* node_cache for i in range(n_samples): node = root_node @@ -322,8 +323,8 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, cdef double total_weight = 0.0 cdef Node *current_node cdef SIZE_t[::1] node_stack = np_zeros((stack_capacity,), dtype=np.intp) - cdef UINT32_t** cachebits - cdef UINT32_t* node_cache + cdef BITSET_t** cachebits + cdef BITSET_t* node_cache # Make category cache buffers for this tree's nodes cache_mgr = CategoryCacheMgr() diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 9ad0670a04a72..eb1422ab4f0f1 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -14,7 +14,7 @@ import numpy as np cimport numpy as np -from ._utils cimport SplitValue, SplitRecord, BitSet +from ._utils cimport SplitValue, SplitRecord, BITSET_t from ._criterion cimport Criterion @@ -61,8 +61,8 @@ cdef class Splitter: cdef DOUBLE_t* sample_weight cdef INT32_t[:] n_categories # (n_features,) array giving number of # categories (<0 for non-categorical) - cdef UINT32_t* cat_cache # Cache buffer for fast categorical split evaluation - cdef BitSet cat_split + cdef BITSET_t* cat_cache # Cache buffer for fast categorical split evaluation + cdef BITSET_t cat_split # cat_split as a bitset # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. 
With this setting, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 0df28fdf94ba1..429cbc36b3d57 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -36,7 +36,8 @@ from ._utils cimport RAND_R_MAX from ._utils cimport safe_realloc from ._utils cimport setup_cat_cache from ._utils cimport goes_left -from ._utils cimport BitSet +from ._utils cimport (BITSET_t, bs_get, bs_set, bs_flip_all, bs_reset_all, + bs_from_template) cdef double INFINITY = np.inf @@ -209,8 +210,8 @@ cdef class Splitter: # If needed, allocate cache space for categorical splits cdef INT32_t max_n_categories = max(self.n_categories) if max_n_categories > 0: - safe_realloc(&self.cat_cache, (max_n_categories + 31) // 32, - sizeof(UINT32_t)) + cache_size = (max_n_categories + 63) // 64 + safe_realloc(&self.cat_cache, cache_size, sizeof(BITSET_t)) return 0 @@ -290,7 +291,7 @@ cdef class BaseDenseSplitter(Splitter): self.X_idx_sorted_stride = 0 self.sample_mask = NULL self.presort = presort - self.cat_split = BitSet() + self.cat_split = 0 def __dealloc__(self): """Destructor.""" @@ -443,6 +444,7 @@ cdef class BestSplitter(BaseDenseSplitter): cdef INT32_t cat_offs[64] cdef bint breiman_shortcut = self.breiman_shortcut cdef SIZE_t sorted_cat[64] + cdef BITSET_t *cat_split_ptr = &self.cat_split _init_split(&best, end) @@ -530,13 +532,13 @@ cdef class BestSplitter(BaseDenseSplitter): # Identify the number of categories present in this node is_categorical = self.n_categories[current.feature] > 0 if is_categorical: - self.cat_split.reset_all() + bs_reset_all(cat_split_ptr) ncat_present = 0 for i in range(start, end): # Xf[i] < 64 already verified in tree.py - self.cat_split.set(Xf[i]) + bs_set(cat_split_ptr, Xf[i]) for i in range(self.n_categories[current.feature]): - if self.cat_split.get(i): + if bs_get(self.cat_split, i): cat_offs[ncat_present] = i - ncat_present ncat_present += 1 if ncat_present <= 3: @@ -558,12 +560,13 @@ cdef class BestSplitter(BaseDenseSplitter): if cat_idx >= ncat_present: break - self.cat_split.reset_all() + bs_reset_all(cat_split_ptr) for ui in range(cat_idx): - self.cat_split.set(sorted_cat[ui]) + bs_set(cat_split_ptr, sorted_cat[ui]) # check if the first bit is 1, if yes, flip all - if self.cat_split.get(0): - self.cat_split.flip_all( + if bs_get(self.cat_split, 0): + bs_flip_all( + cat_split_ptr, self.n_categories[current.feature]) else: if cat_idx >= ( 1) << (ncat_present - 1): @@ -573,15 +576,14 @@ cdef class BestSplitter(BaseDenseSplitter): # cat_split. We double cat_idx to avoid # double-counting equivalent splits. 
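For reference, the sample partition this splitter performs once a candidate
cat_split is built looks like this in pure Python (a sketch, not the patch's
code):

    def partition_by_split(Xf, samples, cat_split, start, end):
        # Move samples whose category bit is set in cat_split to the left
        # block; return the boundary position (current.pos in the splitter).
        j, partition_end = start, end
        while j < partition_end:
            if (cat_split >> int(Xf[j])) & 1:      # bs_get(cat_split, Xf[j])
                j += 1
            else:
                partition_end -= 1
                Xf[j], Xf[partition_end] = Xf[partition_end], Xf[j]
                samples[j], samples[partition_end] = (
                    samples[partition_end], samples[j])
        return j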
This also # ensures that cat_split & 1 == 0 as required - self.cat_split.from_template(cat_idx << 1, - cat_offs, - ncat_present) + bs_from_template(cat_split_ptr, cat_idx << 1, + cat_offs, ncat_present) # Partition j = start partition_end = end while j < partition_end: - if self.cat_split.get(Xf[j]): + if bs_get(self.cat_split, Xf[j]): j += 1 else: partition_end -= 1 @@ -628,7 +630,7 @@ cdef class BestSplitter(BaseDenseSplitter): if current_proxy_improvement > best_proxy_improvement: best_proxy_improvement = current_proxy_improvement if is_categorical: - current.split_value.cat_split = self.cat_split.value + current.split_value.cat_split = self.cat_split else: current.split_value.threshold = (Xf[p - 1] + Xf[p]) / 2.0 if (current.split_value.threshold == Xf[p] diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 08ac5f33f2d08..d8583af8a55eb 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -12,6 +12,8 @@ # See _tree.pyx for details. +from cpython cimport Py_INCREF, PyObject + import numpy as np cimport numpy as np @@ -24,6 +26,7 @@ ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer from ._utils cimport SplitValue from ._utils cimport SplitRecord from ._utils cimport Node +from ._utils cimport BITSET_t from ._splitter cimport Splitter @@ -31,7 +34,7 @@ cdef class CategoryCacheMgr: # Class to manage the category cache memory during Tree.apply() cdef SIZE_t n_nodes - cdef UINT32_t **bits + cdef BITSET_t **bits cdef void populate(self, Node *nodes, SIZE_t n_nodes, INT32_t *n_categories) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 5bdda08eeceb5..f3ef8fe28c709 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -565,9 +565,10 @@ cdef class CategoryCacheMgr: if nodes[i].left_child != _TREE_LEAF: ncat = n_categories[nodes[i].feature] if ncat > 0: + cache_size = (ncat + 63) // 64 safe_realloc(&self.bits[i], - (ncat + 31) // 32, - sizeof(UINT32_t)) + cache_size, + sizeof(BITSET_t)) setup_cat_cache(self.bits[i], nodes[i].split_value.cat_split, ncat) @@ -905,8 +906,8 @@ cdef class Tree: cdef SIZE_t i = 0 cache_mgr = CategoryCacheMgr() cache_mgr.populate(self.nodes, self.node_count, self.n_categories) - cdef UINT32_t** cat_caches = cache_mgr.bits - cdef UINT32_t* cache = NULL + cdef BITSET_t** cat_caches = cache_mgr.bits + cdef BITSET_t* cache = NULL with nogil: for i in range(n_samples): @@ -963,8 +964,8 @@ cdef class Tree: cdef INT32_t k = 0 cache_mgr = CategoryCacheMgr() cache_mgr.populate(self.nodes, self.node_count, self.n_categories) - cdef UINT32_t** cat_caches = cache_mgr.bits - cdef UINT32_t* cache = NULL + cdef BITSET_t** cat_caches = cache_mgr.bits + cdef BITSET_t* cache = NULL # feature_to_sample as a data structure records the last seen sample # for each feature; functionally, it is an efficient way to identify @@ -1049,8 +1050,8 @@ cdef class Tree: cdef SIZE_t i = 0 cache_mgr = CategoryCacheMgr() cache_mgr.populate(self.nodes, self.node_count, self.n_categories) - cdef UINT32_t** cat_caches = cache_mgr.bits - cdef UINT32_t* cache = NULL + cdef BITSET_t** cat_caches = cache_mgr.bits + cdef BITSET_t* cache = NULL with nogil: for i in range(n_samples): @@ -1124,8 +1125,8 @@ cdef class Tree: cdef INT32_t k = 0 cache_mgr = CategoryCacheMgr() cache_mgr.populate(self.nodes, self.node_count, self.n_categories) - cdef UINT32_t** cat_caches = cache_mgr.bits - cdef UINT32_t* cache = NULL + cdef BITSET_t** cat_caches = cache_mgr.bits + cdef BITSET_t* cache = NULL # feature_to_sample as a data structure records the last seen 
sample # for each feature; functionally, it is an efficient way to identify diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index f059759cafa6b..31125f1db23c6 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -21,6 +21,7 @@ ctypedef np.npy_intp SIZE_t # Type for indices and counters ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer ctypedef np.npy_uint64 UINT64_t # Unsigned 64 bit integer +ctypedef UINT64_t BITSET_t ctypedef union SplitValue: # Union type to generalize the concept of a threshold to @@ -97,6 +98,7 @@ ctypedef fused realloc_ptr: (void**) (INT32_t*) (UINT32_t*) + (BITSET_t*) cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems, size_t elem_bytes) nogil except * @@ -116,12 +118,12 @@ cdef double rand_uniform(double low, double high, cdef double log(double x) nogil -cdef void setup_cat_cache(UINT32_t* cachebits, UINT64_t cat_split, +cdef void setup_cat_cache(BITSET_t* cachebits, UINT64_t cat_split, INT32_t n_categories) nogil cdef bint goes_left(DTYPE_t feature_value, SplitValue split, - INT32_t n_categories, UINT32_t* cachebits) nogil + INT32_t n_categories, BITSET_t* cachebits) nogil # ============================================================================= @@ -232,15 +234,12 @@ cdef class WeightedMedianCalculator: cdef DOUBLE_t get_median(self) nogil -cdef class BitSet: - cdef UINT64_t value - - cdef inline void reset_all(self) nogil - cdef inline void set(self, SIZE_t i) nogil - cdef inline void reset(self, SIZE_t i) nogil - cdef inline void flip(self, SIZE_t i) nogil - cdef inline void flip_all(self, SIZE_t n_low_bits) nogil - cdef inline int get(self, SIZE_t i) nogil - cdef inline void from_template(self, UINT64_t template, - INT32_t *cat_offs, - SIZE_t ncats_present) nogil +cdef void bs_reset_all(BITSET_t *value) nogil +cdef void bs_set(BITSET_t *value, SIZE_t i) nogil +cdef void bs_reset(BITSET_t *value, SIZE_t i) nogil +cdef void bs_flip(BITSET_t *value, SIZE_t i) nogil +cdef void bs_flip_all(BITSET_t *value, SIZE_t n_low_bits) nogil +cdef bint bs_get(BITSET_t value, SIZE_t i) nogil +cdef void bs_from_template(BITSET_t *value, UINT64_t template, + INT32_t *cat_offs, + SIZE_t ncats_present) nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 983a8a1aafec6..6bd9a1c606949 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -94,40 +94,66 @@ cdef inline double log(double x) nogil: return ln(x) / ln(2.0) -cdef inline void setup_cat_cache(UINT32_t *cachebits, UINT64_t cat_split, +cdef inline void setup_cat_cache(BITSET_t* cachebits, UINT64_t cat_split, INT32_t n_categories) nogil: """Populate the bits of the category cache from a split. + + This function populates cat_split into cachebits. In the case of a + BestSplitter, cachebits is an array of length 1, i.e. maximum 64 categories + supported, and cachebits[0] = cat_split. However, in the case of a random + splitter, there is no limit for the number of categories on a feature, and + cat_split stores the 32 bit random_seed on the highest 32 bit of the + cat_split to generate the random split. The lowest bit of cat_split defines + if it should be interpreted as a random split or a deterministic one, i.e. + 1 indicates a random split. 
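An illustrative pure-Python decoding of this encoding (the helper name
decode_cat_split is hypothetical):

    def decode_cat_split(cat_split):
        if cat_split & 1:                    # LSB set: random split
            rng_seed = cat_split >> 32       # seed lives in the high 32 bits
            return ('random', rng_seed)
        return ('deterministic', cat_split)  # bitset over up to 64 categories

    assert decode_cat_split(0b0110) == ('deterministic', 0b0110)
    assert decode_cat_split((42 << 32) | 1) == ('random', 42)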
""" cdef INT32_t j cdef UINT32_t rng_seed, val - + cdef SIZE_t cache_size = (n_categories + 63) // 64 if n_categories > 0: if cat_split & 1: # RandomSplitter - for j in range((n_categories + 31) // 32): - cachebits[j] = 0 + for j in range(cache_size): + bs_reset_all(&cachebits[j]) rng_seed = cat_split >> 32 for j in range(n_categories): val = rand_int(0, 2, &rng_seed) - cachebits[j // 32] |= val << (j % 32) - """ - rng_seed = cat_split >> 32 - 10 - 1111111111 bits - 2 bytes - for j in range((n_categories + 31) // 32): - cachebits[j] = rand_int(0, 1 << 32, &rng_seed) - cachebits[j] &= (1 << (n_categories - (32 * j))) - 1 - """ + if not val: + continue + bs_set(&cachebits[j // 64], j % 64) else: # BestSplitter - for j in range((n_categories + 31) // 32): - cachebits[j] = (cat_split >> (j * 32)) & 0xFFFFFFFF + # In practice, cache_size here should ALWAYS be 1 + # XXX TODO: check cache_size == 1? + cachebits[0] = cat_split cdef inline bint goes_left(DTYPE_t feature_value, SplitValue split, - INT32_t n_categories, UINT32_t* cachebits) nogil: - """Determine whether a sample goes to the left or right child node.""" + INT32_t n_categories, BITSET_t *cachebits) nogil: + """Determine whether a sample goes to the left or right child node. + + Attributes + ---------- + feature_value : DTYPE_t + The value of the feature for which the decision needs to be made. + + split : SplitValue + The union (of DOUBLE_t and BITSET_t) indicating the split. However, it + is used (as a DOUBLE_t) only for numerical features. + + n_categories : INT32_t + The number of categories present in the feature in question. The + feature is considered a numerical one and not a categorical one if + n_categories is negative. + + cachebits : BITSET_t* + The array containing the expantion of split.cat_split. The function + setup_cat_cache is the one filling it. + + Returns + ------- + bint : Indicating whether the left branch should be used. 
+ """ cdef SIZE_t idx, shift if n_categories < 0: @@ -136,9 +162,9 @@ cdef inline bint goes_left(DTYPE_t feature_value, SplitValue split, else: # Categorical feature, using bit cache if ( feature_value) < n_categories: - idx = ( feature_value) // 32 - shift = ( feature_value) % 32 - return (cachebits[idx] >> shift) & 1 + idx = ( feature_value) // 64 + offset = ( feature_value) % 64 + return bs_get(cachebits[idx], offset) else: return 0 @@ -725,34 +751,29 @@ cdef class WeightedMedianCalculator: return self.samples.get_value_from_index(self.k-1) -cdef class BitSet: - """Easy bit operations on a UINT64_t value""" - def __cinit__(self): - self.value = 0 - - cdef inline void reset_all(self) nogil: - self.value = 0 +cdef inline void bs_reset_all(BITSET_t *value) nogil: + value[0] = 0 - cdef inline void set(self, SIZE_t i) nogil: - self.value |= ( 1) << i +cdef inline void bs_set(BITSET_t *value, SIZE_t i) nogil: + value[0] |= ( 1) << i - cdef inline void reset(self, SIZE_t i) nogil: - self.value &= ~(( 1) << i) +cdef inline void bs_reset(BITSET_t *value, SIZE_t i) nogil: + value[0] &= ~(( 1) << i) - cdef inline void flip(self, SIZE_t i) nogil: - self.value ^= ( 1) << i +cdef inline void bs_flip(BITSET_t *value, SIZE_t i) nogil: + value[0] ^= ( 1) << i - cdef inline void flip_all(self, SIZE_t n_low_bits) nogil: - self.value = (~self.value) & ((~( 0)) >> (64 - n_low_bits)) +cdef inline void bs_flip_all(BITSET_t *value, SIZE_t n_low_bits) nogil: + value[0] = (~value[0]) & ((~( 0)) >> (64 - n_low_bits)) - cdef inline int get(self, SIZE_t i) nogil: - return (self.value >> i) & ( 1) +cdef inline bint bs_get(BITSET_t value, SIZE_t i) nogil: + return (value >> i) & ( 1) - cdef inline void from_template(self, UINT64_t template, - INT32_t *cat_offs, - SIZE_t ncats_present) nogil: - cdef SIZE_t i - self.value = 0 - for i in range(ncats_present): - self.value |= (template & - (( 1) << i)) << cat_offs[i] +cdef inline void bs_from_template(BITSET_t *value, UINT64_t template, + INT32_t *cat_offs, + SIZE_t ncats_present) nogil: + cdef SIZE_t i + value[0] = 0 + for i in range(ncats_present): + value[0] |= (template & + (( 1) << i)) << cat_offs[i] From 1a78f1a632d2c1db643a874226918f43a04ae78d Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 6 Feb 2019 11:13:52 +0100 Subject: [PATCH 50/54] better docstring for cache functions --- sklearn/tree/_utils.pyx | 42 +++++++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 6bd9a1c606949..6324cf20e49bd 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -94,18 +94,31 @@ cdef inline double log(double x) nogil: return ln(x) / ln(2.0) -cdef inline void setup_cat_cache(BITSET_t* cachebits, UINT64_t cat_split, +cdef inline void setup_cat_cache(BITSET_t* cachebits, BITSET_t cat_split, INT32_t n_categories) nogil: """Populate the bits of the category cache from a split. - This function populates cat_split into cachebits. In the case of a - BestSplitter, cachebits is an array of length 1, i.e. maximum 64 categories - supported, and cachebits[0] = cat_split. However, in the case of a random - splitter, there is no limit for the number of categories on a feature, and - cat_split stores the 32 bit random_seed on the highest 32 bit of the - cat_split to generate the random split. The lowest bit of cat_split defines - if it should be interpreted as a random split or a deterministic one, i.e. - 1 indicates a random split. 
+    Parameters
+    ----------
+    cachebits : BITSET_t*
+        This is a pointer to the output array. The size of the array should
+        be ``ceil(n_categories / 64)``. This function assumes the required
+        memory is allocated for the array by the caller.
+
+    cat_split : BITSET_t
+        If ``least significant bit == 0``:
+            It stores the split of up to 64 categories in its bits.
+            This is used in `BestSplitter`, and without loss of generality it
+            is assumed to be even, i.e. for any odd value there is an
+            equivalent even ``cat_split``.
+        If ``least significant bit == 1``:
+            It is a random split, and the 32 most significant bits of
+            ``cat_split`` contain the random seed of the split. The
+            ``n_categories`` lowest bits of ``cachebits`` are then filled with
+            random zeros and ones given the random seed.
+
+    n_categories : INT32_t
+        The number of categories.
     """
     cdef INT32_t j
     cdef UINT32_t rng_seed, val
@@ -132,6 +145,14 @@ cdef inline bint goes_left(DTYPE_t feature_value, SplitValue split,
                            INT32_t n_categories, BITSET_t *cachebits) nogil:
     """Determine whether a sample goes to the left or right child node.
 
+    For numerical features, ``(-inf, split.threshold]`` is the left child, and
+    ``(split.threshold, inf)`` the right child.
+
+    For categorical features, if the corresponding bit for the category is set
+    in cachebits, the left child is used, and if not set, the right child. If
+    the given input category is larger than or equal to ``n_categories``, the
+    right child is assumed.
+
     Parameters
     ----------
     feature_value : DTYPE_t
@@ -152,7 +173,8 @@ cdef inline bint goes_left(DTYPE_t feature_value, SplitValue split,
 
     Returns
     -------
-    bint : Indicating whether the left branch should be used.
+    result : bint
+        Indicating whether the left branch should be used.

From 646a86ab796db98a0be46cc20667bf5a9d4ffb97 Mon Sep 17 00:00:00 2001
From: adrinjalali
Date: Wed, 6 Feb 2019 14:48:23 +0100
Subject: [PATCH 51/54] (gbc) fix realloc size param

---
 sklearn/neighbors/quad_tree.pyx | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sklearn/neighbors/quad_tree.pyx b/sklearn/neighbors/quad_tree.pyx
index b2a7de0d1ebed..26491e924e0f6 100644
--- a/sklearn/neighbors/quad_tree.pyx
+++ b/sklearn/neighbors/quad_tree.pyx
@@ -605,7 +605,7 @@ cdef class _QuadTree:
         else:
             capacity = 2 * self.capacity
 
-        safe_realloc(&self.cells, capacity, sizeof(self.cells[0]))
+        safe_realloc(&self.cells, capacity, sizeof(Cell))
 
         # if capacity smaller than cell_count, adjust the counter
         if capacity < self.cell_count:

From 6be8edd5ab9d0e495f2070f8973c51f12581855e Mon Sep 17 00:00:00 2001
From: adrinjalali
Date: Wed, 6 Feb 2019 15:19:16 +0100
Subject: [PATCH 52/54] bs_* with no pointers

---
 sklearn/tree/_splitter.pxd |  1 -
 sklearn/tree/_splitter.pyx | 31 ++++++++++++++++---------------
 sklearn/tree/_tree.pxd     |  2 --
 sklearn/tree/_tree.pyx     |  2 --
 sklearn/tree/_utils.pxd    | 15 +++++++--------
 sklearn/tree/_utils.pyx    | 36 +++++++++++++++++-------------------
 6 files changed, 40 insertions(+), 47 deletions(-)

diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd
index eb1422ab4f0f1..7aaf4f455d1fd 100644
--- a/sklearn/tree/_splitter.pxd
+++ b/sklearn/tree/_splitter.pxd
@@ -62,7 +62,6 @@ cdef class Splitter:
     cdef INT32_t[:] n_categories  # (n_features,) array giving number of
                                   # categories (<0 for non-categorical)
     cdef BITSET_t* cat_cache      # Cache buffer for fast categorical split evaluation
-    cdef BITSET_t cat_split       # cat_split as a bitset
 
     # The samples vector `samples` is maintained by the Splitter object such
     # that
the samples contained in a node are contiguous. With this setting, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 429cbc36b3d57..de49219abc19a 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -36,7 +36,7 @@ from ._utils cimport RAND_R_MAX from ._utils cimport safe_realloc from ._utils cimport setup_cat_cache from ._utils cimport goes_left -from ._utils cimport (BITSET_t, bs_get, bs_set, bs_flip_all, bs_reset_all, +from ._utils cimport (BITSET_t, bs_get, bs_set, bs_flip_all, bs_from_template) @@ -291,7 +291,6 @@ cdef class BaseDenseSplitter(Splitter): self.X_idx_sorted_stride = 0 self.sample_mask = NULL self.presort = presort - self.cat_split = 0 def __dealloc__(self): """Destructor.""" @@ -444,7 +443,7 @@ cdef class BestSplitter(BaseDenseSplitter): cdef INT32_t cat_offs[64] cdef bint breiman_shortcut = self.breiman_shortcut cdef SIZE_t sorted_cat[64] - cdef BITSET_t *cat_split_ptr = &self.cat_split + cdef BITSET_t cat_split = 0 _init_split(&best, end) @@ -532,13 +531,13 @@ cdef class BestSplitter(BaseDenseSplitter): # Identify the number of categories present in this node is_categorical = self.n_categories[current.feature] > 0 if is_categorical: - bs_reset_all(cat_split_ptr) + cat_split = 0 ncat_present = 0 for i in range(start, end): # Xf[i] < 64 already verified in tree.py - bs_set(cat_split_ptr, Xf[i]) + cat_split = bs_set(cat_split, Xf[i]) for i in range(self.n_categories[current.feature]): - if bs_get(self.cat_split, i): + if bs_get(cat_split, i): cat_offs[ncat_present] = i - ncat_present ncat_present += 1 if ncat_present <= 3: @@ -560,13 +559,14 @@ cdef class BestSplitter(BaseDenseSplitter): if cat_idx >= ncat_present: break - bs_reset_all(cat_split_ptr) + cat_split = 0 for ui in range(cat_idx): - bs_set(cat_split_ptr, sorted_cat[ui]) + cat_split = bs_set(cat_split, + sorted_cat[ui]) # check if the first bit is 1, if yes, flip all - if bs_get(self.cat_split, 0): - bs_flip_all( - cat_split_ptr, + if bs_get(cat_split, 0): + cat_split = bs_flip_all( + cat_split, self.n_categories[current.feature]) else: if cat_idx >= ( 1) << (ncat_present - 1): @@ -576,14 +576,15 @@ cdef class BestSplitter(BaseDenseSplitter): # cat_split. We double cat_idx to avoid # double-counting equivalent splits. This also # ensures that cat_split & 1 == 0 as required - bs_from_template(cat_split_ptr, cat_idx << 1, - cat_offs, ncat_present) + cat_split = bs_from_template( + cat_idx << 1, + cat_offs, ncat_present) # Partition j = start partition_end = end while j < partition_end: - if bs_get(self.cat_split, Xf[j]): + if bs_get(cat_split, Xf[j]): j += 1 else: partition_end -= 1 @@ -630,7 +631,7 @@ cdef class BestSplitter(BaseDenseSplitter): if current_proxy_improvement > best_proxy_improvement: best_proxy_improvement = current_proxy_improvement if is_categorical: - current.split_value.cat_split = self.cat_split + current.split_value.cat_split = cat_split else: current.split_value.threshold = (Xf[p - 1] + Xf[p]) / 2.0 if (current.split_value.threshold == Xf[p] diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index d8583af8a55eb..3839f837ce2d3 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -12,8 +12,6 @@ # See _tree.pyx for details. 
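The pointer-free style this patch moves to can be sketched in Python, where
each helper returns the new bitset instead of writing through a pointer
(illustrative; assumes a node with categories {0, 2, 3} present and
n_categories == 5):

    def bs_set(value, i):
        return value | (1 << i)

    def bs_flip_all(value, n_low_bits):
        # Complement the n_low_bits lowest bits; the mask plays the role of
        # (~(<UINT64_t> 0)) >> (64 - n_low_bits) in the Cython code.
        return ~value & ((1 << n_low_bits) - 1)

    cat_split = 0
    for cat in (0, 2, 3):                 # categories seen in the node
        cat_split = bs_set(cat_split, cat)
    if cat_split & 1:                     # bit 0 set: store the complement
        cat_split = bs_flip_all(cat_split, 5)
    assert cat_split == 0b10010           # categories 1 and 4 go left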
-from cpython cimport Py_INCREF, PyObject - import numpy as np cimport numpy as np diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index f3ef8fe28c709..926aeb35e1758 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -180,7 +180,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): if sample_weight is not None: sample_weight_ptr = sample_weight.data - cdef INT32_t *n_categories_ptr = NULL if n_categories is not None: n_categories = np.asarray(n_categories, dtype=np.int32, order='C') @@ -355,7 +354,6 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if sample_weight is not None: sample_weight_ptr = sample_weight.data - cdef INT32_t *n_categories_ptr = NULL if n_categories is not None: n_categories = np.asarray(n_categories, dtype=np.int32, order='C') diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 31125f1db23c6..121cdaafc8418 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -234,12 +234,11 @@ cdef class WeightedMedianCalculator: cdef DOUBLE_t get_median(self) nogil -cdef void bs_reset_all(BITSET_t *value) nogil -cdef void bs_set(BITSET_t *value, SIZE_t i) nogil -cdef void bs_reset(BITSET_t *value, SIZE_t i) nogil -cdef void bs_flip(BITSET_t *value, SIZE_t i) nogil -cdef void bs_flip_all(BITSET_t *value, SIZE_t n_low_bits) nogil +cdef BITSET_t bs_set(BITSET_t value, SIZE_t i) nogil +cdef BITSET_t bs_reset(BITSET_t value, SIZE_t i) nogil +cdef BITSET_t bs_flip(BITSET_t value, SIZE_t i) nogil +cdef BITSET_t bs_flip_all(BITSET_t value, SIZE_t n_low_bits) nogil cdef bint bs_get(BITSET_t value, SIZE_t i) nogil -cdef void bs_from_template(BITSET_t *value, UINT64_t template, - INT32_t *cat_offs, - SIZE_t ncats_present) nogil +cdef BITSET_t bs_from_template(UINT64_t template, + INT32_t *cat_offs, + SIZE_t ncats_present) nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 6324cf20e49bd..619b0937a77e9 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -127,13 +127,13 @@ cdef inline void setup_cat_cache(BITSET_t* cachebits, BITSET_t cat_split, if cat_split & 1: # RandomSplitter for j in range(cache_size): - bs_reset_all(&cachebits[j]) + cachebits[j] = 0 rng_seed = cat_split >> 32 for j in range(n_categories): val = rand_int(0, 2, &rng_seed) if not val: continue - bs_set(&cachebits[j // 64], j % 64) + cachebits[j // 64] == bs_set(cachebits[j // 64], j % 64) else: # BestSplitter # In practice, cache_size here should ALWAYS be 1 @@ -773,29 +773,27 @@ cdef class WeightedMedianCalculator: return self.samples.get_value_from_index(self.k-1) -cdef inline void bs_reset_all(BITSET_t *value) nogil: - value[0] = 0 +cdef inline BITSET_t bs_set(BITSET_t value, SIZE_t i) nogil: + return value | ( 1) << i -cdef inline void bs_set(BITSET_t *value, SIZE_t i) nogil: - value[0] |= ( 1) << i +cdef inline BITSET_t bs_reset(BITSET_t value, SIZE_t i) nogil: + return value & ~(( 1) << i) -cdef inline void bs_reset(BITSET_t *value, SIZE_t i) nogil: - value[0] &= ~(( 1) << i) +cdef inline BITSET_t bs_flip(BITSET_t value, SIZE_t i) nogil: + return value ^ ( 1) << i -cdef inline void bs_flip(BITSET_t *value, SIZE_t i) nogil: - value[0] ^= ( 1) << i - -cdef inline void bs_flip_all(BITSET_t *value, SIZE_t n_low_bits) nogil: - value[0] = (~value[0]) & ((~( 0)) >> (64 - n_low_bits)) +cdef inline BITSET_t bs_flip_all(BITSET_t value, SIZE_t n_low_bits) nogil: + return (~value) & ((~( 0)) >> (64 - n_low_bits)) cdef inline bint bs_get(BITSET_t value, SIZE_t i) nogil: return (value >> i) & ( 1) -cdef inline void 
bs_from_template(BITSET_t *value, UINT64_t template,
-                                  INT32_t *cat_offs,
-                                  SIZE_t ncats_present) nogil:
+cdef inline BITSET_t bs_from_template(UINT64_t template,
+                                      INT32_t *cat_offs,
+                                      SIZE_t ncats_present) nogil:
     cdef SIZE_t i
-    value[0] = 0
+    cdef BITSET_t value = 0
     for i in range(ncats_present):
-        value[0] |= (template &
-                     ((<UINT64_t> 1) << i)) << cat_offs[i]
+        value |= (template &
+                  ((<UINT64_t> 1) << i)) << cat_offs[i]
+    return value

From 5d04b1853cac5c087c6fc285356df971394e56e6 Mon Sep 17 00:00:00 2001
From: adrinjalali
Date: Wed, 6 Feb 2019 16:36:48 +0100
Subject: [PATCH 53/54] fix SplitValue description

---
 sklearn/tree/_utils.pxd | 31 +++++++++++++++----------------
 1 file changed, 15 insertions(+), 16 deletions(-)

diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd
index 121cdaafc8418..7d2cf332be241 100644
--- a/sklearn/tree/_utils.pxd
+++ b/sklearn/tree/_utils.pxd
@@ -24,24 +24,23 @@ ctypedef np.npy_uint64 UINT64_t          # Unsigned 64 bit integer
 ctypedef UINT64_t BITSET_t
 
 ctypedef union SplitValue:
-    # Union type to generalize the concept of a threshold to
-    # categorical features. For non-categorical features, use the
-    # threshold member. It acts just as before, where feature values
-    # less than or equal to the threshold go left, and values greater
-    # than the threshold go right.
+    # Union type to generalize the concept of a threshold to categorical
+    # features. The floating point view, i.e. ``SplitValue.threshold``, is
+    # used for numerical features, where feature values less than or equal to
+    # the threshold go left, and values greater than the threshold go right.
     #
-    # For categorical features, use the cat_split member. It works in
-    # one of two ways, indicated by the value of its least significant
-    # bit (LSB). If the LSB is 0, then cat_split acts as a bitfield
-    # for up to 64 categories, sending samples left if the bit
-    # corresponding to their category is 1 or right if it is 0. If the
-    # LSB is 1, then the more significant 32 bits of cat_split is a
-    # random seed. To evaluate a sample, use the random seed to flip a
-    # coin (category_value + 1) times and send it left if the last
-    # flip gives 1; otherwise right. This second method allows up to
-    # 2**31 category values, but can only be used for RandomSplitter.
+    # For categorical features, the BITSET_t view (``SplitValue.cat_split``)
+    # is used. It works in one of two ways, indicated by the value of its
+    # least significant bit (LSB). If the LSB is 0, then cat_split acts as a
+    # bitfield for up to 64 categories, sending samples left if the bit
+    # corresponding to their category is 1 or right if it is 0. If the LSB is
+    # 1, then the most significant 32 bits of cat_split make a random seed.
+    # To evaluate a sample, use the random seed to flip a coin
+    # (category_value + 1) times and send it left if the last flip gives 1;
+    # otherwise right. This second method allows up to 2**31 category values,
+    # but can only be used for RandomSplitter.
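A sketch of the coin-flip rule described above, substituting Python's random
module for the library's internal rand_r replacement (the random stream
differs, but the control flow is the same):

    import random

    def random_split_goes_left(category_value, rng_seed):
        rng = random.Random(rng_seed)
        flip = 0
        for _ in range(category_value + 1):  # flip (category_value + 1) times
            flip = rng.randrange(2)
        return flip == 1                     # left iff the last flip gives 1

In the tree itself this loop is not run per sample: setup_cat_cache
precomputes the flip for every category once per node, and prediction then
only does a bit lookup.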
DOUBLE_t threshold - UINT64_t cat_split + BITSET_t cat_split ctypedef struct SplitRecord: From f44152ba8040fa4e1c1ba131c7e2891ff39588c9 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 25 Feb 2019 16:44:26 +0100 Subject: [PATCH 54/54] fix silly == bug --- sklearn/tree/_utils.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 619b0937a77e9..ecfcbdf308c52 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -133,7 +133,7 @@ cdef inline void setup_cat_cache(BITSET_t* cachebits, BITSET_t cat_split, val = rand_int(0, 2, &rng_seed) if not val: continue - cachebits[j // 64] == bs_set(cachebits[j // 64], j % 64) + cachebits[j // 64] = bs_set(cachebits[j // 64], j % 64) else: # BestSplitter # In practice, cache_size here should ALWAYS be 1
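A closing note on why the bug fixed here was silent: a bare `==` comparison is
a valid expression statement in both Python and Cython, so its result was
simply discarded and the cache word was never written. Illustration:

    cachebits = [0]
    cachebits[0] == (cachebits[0] | (1 << 3))   # no-op: compares, discards
    assert cachebits[0] == 0                    # cache unchanged -- the bug
    cachebits[0] = cachebits[0] | (1 << 3)      # the fix: actually assign
    assert cachebits[0] == 0b1000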