NOCATS: Categorical splits for tree-based learners #4899

Closed · wants to merge 13 commits
65 changes: 47 additions & 18 deletions sklearn/ensemble/_gradient_boosting.pyx
@@ -20,10 +20,13 @@ from scipy.sparse import csr_matrix

from sklearn.tree._tree cimport Node
from sklearn.tree._tree cimport Tree
from sklearn.tree._tree cimport CategoryCacheMgr
from sklearn.tree._tree cimport DTYPE_t
from sklearn.tree._tree cimport SIZE_t
from sklearn.tree._tree cimport INT32_t
from sklearn.tree._tree cimport UINT32_t
from sklearn.tree._utils cimport safe_realloc
from sklearn.tree._utils cimport goes_left

ctypedef np.int32_t int32
ctypedef np.float64_t float64
@@ -48,6 +51,8 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X,
Py_ssize_t K,
Py_ssize_t n_samples,
Py_ssize_t n_features,
INT32_t* n_categories,
UINT32_t** cachebits,
float64 *out):
"""Predicts output for regression tree and stores it in ``out[i, k]``.

@@ -82,20 +87,32 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X,
``n_samples == X.shape[0]``.
n_features : int
The number of features; ``n_samples == X.shape[1]``.
n_categories : INT32_t pointer
Array of length n_features containing the number of categories
(for categorical features) or -1 (for non-categorical features).
cachebits : UINT32_t pointer pointer
Array of length node_count containing category cache buffers
for categorical features.
out : np.float64_t pointer
The pointer to the data array where the predictions are stored.
``out`` is assumed to be a two-dimensional array of
shape ``(n_samples, K)``.
"""
cdef Py_ssize_t i
cdef Node *node
cdef UINT32_t* node_cache

for i in range(n_samples):
node = root_node
node_cache = cachebits[0]
# While node not a leaf
while node.left_child != TREE_LEAF:
if X[i * n_features + node.feature] <= node.threshold:
if goes_left(X[i * n_features + node.feature], node.split_value,
n_categories[node.feature], node_cache):
node_cache = cachebits[node.left_child]
node = root_node + node.left_child
else:
node_cache = cachebits[node.right_child]
node = root_node + node.right_child
out[i * K + k] += scale * value[node - root_node]

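For orientation, the dense-prediction loop above replaces the plain threshold test with goes_left, which must handle both ordinal and categorical splits. Below is a minimal Python sketch of that membership test, assuming the per-node cache is a list of 32-bit words whose set bits mark categories routed to the left child; the names and the unseen-category convention are illustrative, not the PR's exact Cython API.

def goes_left_sketch(feature_value, threshold, n_categories, cache_bits):
    """Return True if a sample should descend to the left child."""
    if n_categories < 0:
        # Ordinal feature: the usual numeric comparison.
        return feature_value <= threshold
    c = int(feature_value)  # categorical features carry integer codes
    if c < 0 or c >= n_categories:
        return False  # unseen category: route right (assumed convention)
    word, bit = divmod(c, 32)
    return bool((cache_bits[word] >> bit) & 1)

For example, with 8 categories and left set {1, 5} (bitset 0b100010), goes_left_sketch(5.0, 0.0, 8, [0b100010]) returns True.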
@@ -130,8 +147,8 @@ def _predict_regression_tree_stages_sparse(np.ndarray[object, ndim=2] estimators
cdef Tree tree
cdef Node** nodes = NULL
cdef double** values = NULL
safe_realloc(&nodes, n_stages * n_outputs)
safe_realloc(&values, n_stages * n_outputs)
safe_realloc(<void ***>&nodes, n_stages * n_outputs, sizeof(void*))
safe_realloc(<void ***>&values, n_stages * n_outputs, sizeof(void*))
for stage_i in range(n_stages):
for output_i in range(n_outputs):
tree = estimators[stage_i, output_i].tree_
@@ -147,8 +164,8 @@ def _predict_regression_tree_stages_sparse(np.ndarray[object, ndim=2] estimators
# which features are nonzero in the present sample.
cdef SIZE_t* feature_to_sample = NULL

safe_realloc(&X_sample, n_features)
safe_realloc(&feature_to_sample, n_features)
safe_realloc(&X_sample, n_features, sizeof(DTYPE_t))
safe_realloc(&feature_to_sample, n_features, sizeof(SIZE_t))

memset(feature_to_sample, -1, n_features * sizeof(SIZE_t))

@@ -174,7 +191,7 @@ def _predict_regression_tree_stages_sparse(np.ndarray[object, ndim=2] estimators
else:
feature_value = 0.

if feature_value <= node.threshold:
if feature_value <= node.split_value.threshold:
node = root_node + node.left_child
else:
node = root_node + node.right_child
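As an aside, the sparse path above avoids a memset per row with the feature_to_sample trick: each slot records the id of the last sample that wrote it, so stale values are detected by comparison rather than cleared. A hedged, self-contained Python sketch of the same idea (illustrative names, not the module's API):

import numpy as np
from scipy.sparse import csr_matrix

def sparse_row_lookups(X_csr):
    """Yield a per-row lookup function without clearing a dense buffer."""
    n_features = X_csr.shape[1]
    X_sample = np.zeros(n_features)
    feature_to_sample = np.full(n_features, -1)
    for i in range(X_csr.shape[0]):
        start, end = X_csr.indptr[i], X_csr.indptr[i + 1]
        for f, v in zip(X_csr.indices[start:end], X_csr.data[start:end]):
            feature_to_sample[f] = i
            X_sample[f] = v
        # A value is valid only if it was written while visiting row i.
        yield lambda f, i=i: X_sample[f] if feature_to_sample[f] == i else 0.0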
@@ -213,13 +230,18 @@ def predict_stages(np.ndarray[object, ndim=2] estimators,
for k in range(K):
tree = estimators[i, k].tree_

# Make category cache buffers for this tree's nodes
cache_mgr = CategoryCacheMgr()
cache_mgr.populate(tree.nodes, tree.node_count, tree.n_categories)

# avoid buffer validation by casting to ndarray
# and get data pointer
# need brackets because of casting operator priority
_predict_regression_tree_inplace_fast_dense(
<DTYPE_t*> (<np.ndarray> X).data,
tree.nodes, tree.value,
scale, k, K, X.shape[0], X.shape[1],
tree.n_categories, cache_mgr.bits,
<float64 *> (<np.ndarray> out).data)
## out += scale * tree.predict(X).reshape((X.shape[0], 1))

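The CategoryCacheMgr call above prepares per-node category caches once per tree so the inner prediction loop can test membership cheaply. A rough Python analogue of the assumed behaviour (the node field names here are illustrative; the PR stores the packed category bitset in the node's split value):

from collections import namedtuple

# `feature` is negative for leaves; `cat_split` stands in for the packed
# 64-bit category bitset assumed to live in the node's split value.
SketchNode = namedtuple("SketchNode", ["feature", "cat_split"])

def populate_cache_sketch(nodes, n_categories):
    """One cache entry per node: 32-bit words marking left-going
    categories, or None for ordinal splits and leaves."""
    bits = []
    for node in nodes:
        n_cats = n_categories[node.feature] if node.feature >= 0 else -1
        if n_cats < 0:
            bits.append(None)
        else:
            n_words = (n_cats + 31) // 32
            bits.append([(node.cat_split >> (32 * w)) & 0xFFFFFFFF
                         for w in range(n_words)])
    return bits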
@@ -290,27 +312,34 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X,
cdef SIZE_t node_count = tree.node_count

cdef SIZE_t stack_capacity = node_count * 2
cdef Node **node_stack
cdef double[::1] weight_stack = np_ones((stack_capacity,), dtype=np_float64)
cdef SIZE_t stack_size = 1
cdef double left_sample_frac
cdef double current_weight
cdef double total_weight = 0.0
cdef Node *current_node
underlying_stack = np_zeros((stack_capacity,), dtype=np.intp)
node_stack = <Node **>(<np.ndarray> underlying_stack).data
cdef SIZE_t[::1] node_stack = np_zeros((stack_capacity,), dtype=np.intp)
cdef UINT32_t** cachebits
cdef UINT32_t* node_cache

# Make category cache buffers for this tree's nodes
cache_mgr = CategoryCacheMgr()
cache_mgr.populate(root_node, node_count, tree.n_categories)
cachebits = cache_mgr.bits

for i in range(X.shape[0]):
# init stacks for new example
stack_size = 1
node_stack[0] = root_node
node_stack[0] = 0
node_cache = cachebits[0]
weight_stack[0] = 1.0
total_weight = 0.0

while stack_size > 0:
# get top node on stack
stack_size -= 1
current_node = node_stack[stack_size]
current_node = root_node + node_stack[stack_size]
node_cache = cachebits[node_stack[stack_size]]

if current_node.left_child == TREE_LEAF:
out[i] += weight_stack[stack_size] * value[current_node - root_node] * \
@@ -322,21 +351,21 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X,
if feature_index != -1:
# split feature in target set
# push left or right child on stack
if X[i, feature_index] <= current_node.threshold:
if goes_left(X[i, feature_index], current_node.split_value,
tree.n_categories[current_node.feature],
node_cache):
# left
node_stack[stack_size] = (root_node +
current_node.left_child)
node_stack[stack_size] = current_node.left_child
else:
# right
node_stack[stack_size] = (root_node +
current_node.right_child)
node_stack[stack_size] = current_node.right_child
stack_size += 1
else:
# split feature in complement set
# push both children onto stack

# push left child
node_stack[stack_size] = root_node + current_node.left_child
node_stack[stack_size] = current_node.left_child
current_weight = weight_stack[stack_size]
left_sample_frac = root_node[current_node.left_child].n_node_samples / \
<double>current_node.n_node_samples
@@ -351,7 +380,7 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X,
stack_size +=1

# push right child
node_stack[stack_size] = root_node + current_node.right_child
node_stack[stack_size] = current_node.right_child
weight_stack[stack_size] = current_weight * \
(1.0 - left_sample_frac)
stack_size +=1
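To summarize the traversal that _partial_dependence_tree implements: target-set features follow the branch chosen by the query point, while complement features send weight to both children in proportion to training sample counts. A compact Python sketch of that recursion for ordinal splits only (the PR additionally routes categorical splits through goes_left and the per-node caches; names are illustrative):

from collections import namedtuple

PDNode = namedtuple("PDNode", ["feature", "threshold", "left_child",
                               "right_child", "n_node_samples"])

def partial_dependence_sketch(nodes, values, x, target_features):
    """Average prediction with complement features marginalized out."""
    total = 0.0
    stack = [(0, 1.0)]  # (node index, probability weight)
    while stack:
        idx, w = stack.pop()
        node = nodes[idx]
        if node.left_child == -1:  # leaf (TREE_LEAF is -1 in sklearn)
            total += w * values[idx]
            continue
        if node.feature in target_features:
            # Target feature: follow the branch the query point takes.
            child = (node.left_child if x[node.feature] <= node.threshold
                     else node.right_child)
            stack.append((child, w))
        else:
            # Complement feature: split the weight between both children.
            frac = nodes[node.left_child].n_node_samples / node.n_node_samples
            stack.append((node.left_child, w * frac))
            stack.append((node.right_child, w * (1.0 - frac)))
    return total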
63 changes: 58 additions & 5 deletions sklearn/ensemble/forest.py
@@ -813,6 +817,17 @@ class RandomForestClassifier(ForestClassifier):

.. versionadded:: 0.18

categorical : array-like or str
Array of feature indices, boolean array of length n_features,
``'all'`` or ``'none'``. Indicates which features should be
treated as categorical rather than ordinal. For decision trees,
the maximum number of categories is 64. In practice, the limit will
often be lower because the cost of searching for the best possible
split grows exponentially with the number of categories. However, a
shortcut due to Breiman (1984) is used when fitting data with binary
labels using the ``Gini`` or ``Entropy`` criteria; in that case,
the runtime is linear in the number of categories.

bootstrap : boolean, optional (default=True)
Whether bootstrap samples are used when building trees.

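To make the documented parameter concrete, here is a hedged usage sketch of the API this PR proposes (the categorical keyword is not part of released scikit-learn; data and settings are illustrative):

import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Column 0 holds integer category codes; column 1 stays ordinal.
X = np.array([[0., 1.5], [1., 2.0], [2., 0.5], [1., 1.0]])
y = np.array([0, 1, 1, 0])

clf = RandomForestClassifier(n_estimators=10, categorical=[0],
                             random_state=0)
clf.fit(X, y)
print(clf.predict(X))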
@@ -908,6 +919,7 @@ def __init__(self,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
categorical="none",
bootstrap=True,
oob_score=False,
n_jobs=1,
@@ -921,7 +933,7 @@ def __init__(self,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
"random_state", "categorical"),
bootstrap=bootstrap,
oob_score=oob_score,
n_jobs=n_jobs,
@@ -938,6 +950,7 @@ def __init__(self,
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split
self.categorical = categorical


class RandomForestRegressor(ForestRegressor):
@@ -1025,6 +1038,17 @@ class RandomForestRegressor(ForestRegressor):

.. versionadded:: 0.18

categorical : array-like or str
Array of feature indices, boolean array of length n_features,
``'all'`` or ``'none'``. Indicates which features should be
treated as categorical rather than ordinal. For decision trees,
the maximum number of categories is 64. In practice, the limit will
often be lower because the cost of searching for the best possible
split grows exponentially with the number of categories. However, a
shortcut due to Breiman (1984) is used when fitting data with binary
labels using the ``Gini`` or ``Entropy`` criteria; in that case,
the runtime is linear in the number of categories.

bootstrap : boolean, optional (default=True)
Whether bootstrap samples are used when building trees.

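The Breiman (1984) shortcut referenced in the docstring above rests on a classical fact: for binary targets, if categories are sorted by their mean response, the optimal partition is a prefix of that order, so only k - 1 cut points need scanning instead of 2**(k-1) - 1 subsets. A plain-Python sketch under that assumption (illustrative, not the PR's Cython implementation):

import numpy as np

def best_categorical_split_binary(x_codes, y):
    """Return the category set routed left by the best Gini split."""
    cats = np.unique(x_codes)
    means = np.array([y[x_codes == c].mean() for c in cats])
    order = cats[np.argsort(means)]
    best_gini, best_k = np.inf, 1
    for k in range(1, len(order)):  # scan the k - 1 candidate prefixes
        left = np.isin(x_codes, order[:k])
        gini = 0.0
        for mask in (left, ~left):
            p = y[mask].mean()
            gini += mask.mean() * 2.0 * p * (1.0 - p)  # weighted impurity
        if gini < best_gini:
            best_gini, best_k = gini, k
    return set(order[:best_k])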
@@ -1089,6 +1113,7 @@ def __init__(self,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
categorical="none",
bootstrap=True,
oob_score=False,
n_jobs=1,
@@ -1101,7 +1126,7 @@ def __init__(self,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
"random_state", "categorical"),
bootstrap=bootstrap,
oob_score=oob_score,
n_jobs=n_jobs,
@@ -1117,6 +1142,7 @@ def __init__(self,
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split
self.categorical = categorical


class ExtraTreesClassifier(ForestClassifier):
@@ -1197,6 +1223,13 @@ class ExtraTreesClassifier(ForestClassifier):

.. versionadded:: 0.18

categorical : array-like or str
Array of feature indices, boolean array of length n_features,
``'all'`` or ``'none'``. Indicates which features should be
treated as categorical rather than ordinal. Extra-random trees
have an upper limit of :math:`2^{31}` categories, and runtime
linear in the number of categories.

bootstrap : boolean, optional (default=False)
Whether bootstrap samples are used when building trees.

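For the boolean-mask spelling of the same parameter, a hedged sketch (again, this PR's proposed API rather than released scikit-learn):

import numpy as np
from sklearn.ensemble import ExtraTreesClassifier

X = np.array([[3., 0.2], [1., 0.9], [2., 0.4], [0., 0.7]])
y = np.array([1, 0, 1, 0])

# True marks a categorical column; False leaves it ordinal.
clf = ExtraTreesClassifier(n_estimators=5,
                           categorical=np.array([True, False]),
                           random_state=0)
clf.fit(X, y)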
@@ -1293,6 +1326,7 @@ def __init__(self,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
categorical="none",
bootstrap=False,
oob_score=False,
n_jobs=1,
@@ -1306,7 +1340,7 @@ def __init__(self,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
"random_state", "categorical"),
bootstrap=bootstrap,
oob_score=oob_score,
n_jobs=n_jobs,
@@ -1323,6 +1357,7 @@ def __init__(self,
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split
self.categorical = categorical


class ExtraTreesRegressor(ForestRegressor):
@@ -1408,6 +1443,13 @@ class ExtraTreesRegressor(ForestRegressor):

.. versionadded:: 0.18

categorical : array-like or str
Array of feature indices, boolean array of length n_features,
``'all'`` or ``'none'``. Indicates which features should be
treated as categorical rather than ordinal. Extra-random trees
have an upper limit of :math:`2^{31}` categories, and runtime
linear in the number of categories.

bootstrap : boolean, optional (default=False)
Whether bootstrap samples are used when building trees.

@@ -1473,6 +1515,7 @@ def __init__(self,
max_features="auto",
max_leaf_nodes=None,
min_impurity_split=1e-7,
categorical="none",
bootstrap=False,
oob_score=False,
n_jobs=1,
@@ -1485,7 +1528,7 @@ def __init__(self,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
"random_state", "categorical"),
bootstrap=bootstrap,
oob_score=oob_score,
n_jobs=n_jobs,
@@ -1501,6 +1544,7 @@ def __init__(self,
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split
self.categorical = categorical


class RandomTreesEmbedding(BaseForest):
Expand Down Expand Up @@ -1566,6 +1610,13 @@ class RandomTreesEmbedding(BaseForest):

.. versionadded:: 0.18

categorical : array-like or str
Array of feature indices, boolean array of length n_features,
``'all'`` or ``'none'``. Indicates which features should be
treated as categorical rather than ordinal. Extra-random trees
have an upper limit of :math:`2^{31}` categories, and runtime
linear in the number of categories.

sparse_output : bool, optional (default=True)
Whether or not to return a sparse CSR matrix, as default behavior,
or to return a dense array compatible with dense pipeline operators.
@@ -1611,6 +1662,7 @@ def __init__(self,
min_weight_fraction_leaf=0.,
max_leaf_nodes=None,
min_impurity_split=1e-7,
categorical="none",
sparse_output=True,
n_jobs=1,
random_state=None,
@@ -1622,7 +1674,7 @@ def __init__(self,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
"random_state", "categorical"),
bootstrap=False,
oob_score=False,
n_jobs=n_jobs,
Expand All @@ -1639,6 +1691,7 @@ def __init__(self,
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split
self.sparse_output = sparse_output
self.categorical = categorical

def _set_oob_score(self, X, y):
raise NotImplementedError("OOB score not supported by tree embedding")