From 1881ececb3aa8e95db0d73d691c640a50cc64422 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 29 Jun 2017 15:57:03 -0400 Subject: [PATCH 01/61] add reingold tillford tree layout algorithm --- sklearn/tree/_reingold_tilford.py | 172 ++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) create mode 100644 sklearn/tree/_reingold_tilford.py diff --git a/sklearn/tree/_reingold_tilford.py b/sklearn/tree/_reingold_tilford.py new file mode 100644 index 0000000000000..9aa8820fa98ad --- /dev/null +++ b/sklearn/tree/_reingold_tilford.py @@ -0,0 +1,172 @@ +# taken from https://github.com/llimllib/pymag-trees/blob/master/buchheim.py + + +class DrawTree(object): + def __init__(self, tree, parent=None, depth=0, number=1): + self.x = -1. + self.y = depth + self.tree = tree + self.children = [DrawTree(c, self, depth + 1, i + 1) + for i, c + in enumerate(tree.children)] + self.parent = parent + self.thread = None + self.mod = 0 + self.ancestor = self + self.change = self.shift = 0 + self._lmost_sibling = None + # this is the number of the node in its group of siblings 1..n + self.number = number + + def left(self): + return self.thread or len(self.children) and self.children[0] + + def right(self): + return self.thread or len(self.children) and self.children[-1] + + def lbrother(self): + n = None + if self.parent: + for node in self.parent.children: + if node == self: + return n + else: + n = node + return n + + def get_lmost_sibling(self): + if not self._lmost_sibling and self.parent and self != \ + self.parent.children[0]: + self._lmost_sibling = self.parent.children[0] + return self._lmost_sibling + lmost_sibling = property(get_lmost_sibling) + + def __str__(self): + return "%s: x=%s mod=%s" % (self.tree, self.x, self.mod) + + def __repr__(self): + return self.__str__() + + +def buchheim(tree): + dt = firstwalk(DrawTree(tree)) + min = second_walk(dt) + if min < 0: + third_walk(dt, -min) + return dt + + +def third_walk(tree, n): + tree.x += n + for c in tree.children: + third_walk(c, n) + + +def firstwalk(v, distance=1.): + if len(v.children) == 0: + if v.lmost_sibling: + v.x = v.lbrother().x + distance + else: + v.x = 0. + else: + default_ancestor = v.children[0] + for w in v.children: + firstwalk(w) + default_ancestor = apportion(w, default_ancestor, distance) + # print("finished v =", v.tree, "children") + execute_shifts(v) + + midpoint = (v.children[0].x + v.children[-1].x) / 2 + + ell = v.children[0] + arr = v.children[-1] + w = v.lbrother() + if w: + v.x = w.x + distance + v.mod = v.x - midpoint + else: + v.x = midpoint + return v + + +def apportion(v, default_ancestor, distance): + w = v.lbrother() + if w is not None: + # in buchheim notation: + # i == inner; o == outer; r == right; l == left; r = +; l = - + vir = vor = v + vil = w + vol = v.lmost_sibling + sir = sor = v.mod + sil = vil.mod + sol = vol.mod + while vil.right() and vir.left(): + vil = vil.right() + vir = vir.left() + vol = vol.left() + vor = vor.right() + vor.ancestor = v + shift = (vil.x + sil) - (vir.x + sir) + distance + if shift > 0: + move_subtree(ancestor(vil, v, default_ancestor), v, shift) + sir = sir + shift + sor = sor + shift + sil += vil.mod + sir += vir.mod + sol += vol.mod + sor += vor.mod + if vil.right() and not vor.right(): + vor.thread = vil.right() + vor.mod += sil - sor + else: + if vir.left() and not vol.left(): + vol.thread = vir.left() + vol.mod += sir - sol + default_ancestor = v + return default_ancestor + + +def move_subtree(wl, wr, shift): + subtrees = wr.number - wl.number + # print(wl.tree, "is conflicted with", wr.tree, 'moving', subtrees, + # 'shift', shift) + # print wl, wr, wr.number, wl.number, shift, subtrees, shift/subtrees + wr.change -= shift / subtrees + wr.shift += shift + wl.change += shift / subtrees + wr.x += shift + wr.mod += shift + + +def execute_shifts(v): + shift = change = 0 + for w in v.children[::-1]: + # print("shift:", w, shift, w.change) + w.x += shift + w.mod += shift + change += w.change + shift += w.shift + change + + +def ancestor(vil, v, default_ancestor): + # the relevant text is at the bottom of page 7 of + # "Improving Walker's Algorithm to Run in Linear Time" by Buchheim et al, + # (2002) + # http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.16.8757&rep=rep1&type=pdf + if vil.ancestor in v.parent.children: + return vil.ancestor + else: + return default_ancestor + + +def second_walk(v, m=0, depth=0, min=None): + v.x += m + v.y = depth + + if min is None or v.x < min: + min = v.x + + for w in v.children: + min = second_walk(w, m + v.mod, depth + 1, min) + + return min From bd7d022d7c8f4388b69768d4e0a581c4961d5d3d Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 29 Jun 2017 16:13:28 -0400 Subject: [PATCH 02/61] add first silly implementation of matplotlib based plotting for trees --- sklearn/tree/_reingold_tilford.py | 32 ++++++ sklearn/tree/export.py | 160 ++++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+) diff --git a/sklearn/tree/_reingold_tilford.py b/sklearn/tree/_reingold_tilford.py index 9aa8820fa98ad..4790a90dc6546 100644 --- a/sklearn/tree/_reingold_tilford.py +++ b/sklearn/tree/_reingold_tilford.py @@ -170,3 +170,35 @@ def second_walk(v, m=0, depth=0, min=None): min = second_walk(w, m + v.mod, depth + 1, min) return min + + +# my stuff + +class Tree(object): + def __init__(self, node="", *children): + self.node = node + self.width = len(node) + if children: + self.children = children + else: + self.children = [] + + def __str__(self): + return "%s" % (self.node) + + def __repr__(self): + return "%s" % (self.node) + + def __getitem__(self, key): + if isinstance(key, int) or isinstance(key, slice): + return self.children[key] + if isinstance(key, str): + for child in self.children: + if child.node == key: + return child + + def __iter__(self): + return self.children.__iter__() + + def __len__(self): + return len(self.children) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index f526c771af047..f27611845bb7e 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -21,6 +21,7 @@ from . import _criterion from . import _tree +from ._reingold_tilford import buchheim, Tree def _color_brew(n): @@ -68,9 +69,168 @@ def _color_brew(n): class Sentinel(object): def __repr__(): return '"tree.dot"' + + SENTINEL = Sentinel() +def node_to_str(tree, node_id, criterion): + # stupid copy & paste with a few adjustments + label = 'all' + feature_names = None + class_names = None + label = 'all' + impurity = True + node_ids = False + proportion = False + special_characters = False + precision = 3 + # Generate the node content string + if tree.n_outputs == 1: + value = tree.value[node_id][0, :] + else: + value = tree.value[node_id] + # Should labels be shown? + labels = (label == 'root' and node_id == 0) or label == 'all' + + # PostScript compatibility for special characters + if special_characters: + characters = ['#', '', '', '≤', '
', '>'] + node_string = '<' + else: + characters = ['#', '[', ']', '<=', '\n', ''] + node_string = '' + + # Write node ID + if node_ids: + if labels: + node_string += 'node ' + node_string += characters[0] + str(node_id) + characters[4] + + # Write decision criteria + if tree.children_left[node_id] != _tree.TREE_LEAF: + # Always write node decision criteria, except for leaves + if feature_names is not None: + feature = feature_names[tree.feature[node_id]] + else: + feature = "X%s%s%s" % (characters[1], + tree.feature[node_id], + characters[2]) + node_string += '%s %s %s%s' % (feature, + characters[3], + round(tree.threshold[node_id], + precision), + characters[4]) + + # Write impurity + if impurity: + if isinstance(criterion, _criterion.FriedmanMSE): + criterion = "friedman_mse" + elif not isinstance(criterion, six.string_types): + criterion = "impurity" + if labels: + node_string += '%s = ' % criterion + node_string += (str(round(tree.impurity[node_id], precision)) + + characters[4]) + + # Write node sample count + if labels: + node_string += 'samples = ' + if proportion: + percent = (100. * tree.n_node_samples[node_id] / + float(tree.n_node_samples[0])) + node_string += (str(round(percent, 1)) + '%' + + characters[4]) + else: + node_string += (str(tree.n_node_samples[node_id]) + + characters[4]) + + # Write node class distribution / regression value + if proportion and tree.n_classes[0] != 1: + # For classification this will show the proportion of samples + value = value / tree.weighted_n_node_samples[node_id] + if labels: + node_string += 'value = ' + if tree.n_classes[0] == 1: + # Regression + value_text = np.around(value, precision) + elif proportion: + # Classification + value_text = np.around(value, precision) + elif np.all(np.equal(np.mod(value, 1), 0)): + # Classification without floating-point weights + value_text = value.astype(int) + else: + # Classification with floating-point weights + value_text = np.around(value, precision) + # Strip whitespace + value_text = str(value_text.astype('S32')).replace("b'", "'") + value_text = value_text.replace("' '", ", ").replace("'", "") + if tree.n_classes[0] == 1 and tree.n_outputs == 1: + value_text = value_text.replace("[", "").replace("]", "") + value_text = value_text.replace("\n ", characters[4]) + node_string += value_text + characters[4] + + # Write node majority class + if (class_names is not None and + tree.n_classes[0] != 1 and + tree.n_outputs == 1): + # Only done for single-output classification trees + if labels: + node_string += 'class = ' + if class_names is not True: + class_name = class_names[np.argmax(value)] + else: + class_name = "y%s%s%s" % (characters[1], + np.argmax(value), + characters[2]) + node_string += class_name + + # Clean up any trailing newlines + if node_string[-2:] == '\n': + node_string = node_string[:-2] + if node_string[-5:] == '
': + node_string = node_string[:-5] + + return node_string + characters[5] + + +def _make_tree(node_id, et): + # traverses _tree.Tree recursively, builds intermediate "Tree" object + name = node_to_str(et, 0, criterion='entropy') + if (et.children_left[node_id] != et.children_right[node_id]): + children = [_make_tree(et.children_left[node_id], et), _make_tree( + et.children_right[node_id], et)] + else: + return Tree(name) + return Tree(name, *children) + + +def plot_tree(estimator): + import matplotlib.pyplot as plt + bbox_args = dict(boxstyle="round", fc="0.8") + arrow_args = dict(arrowstyle="-") + + def draw_nodes(node, scale=1, zorder=0): + # 2 - is a hack to for not creating empty space. FIXME + if node.parent is None: + plt.annotate(node.tree, (node.x * scale, (2 - node.y) * scale), + bbox=bbox_args, ha='center', va='bottom', + zorder=zorder, xycoords='axes points') + else: + plt.annotate(node.tree, (node.parent.x * scale, (2 - node.parent.y) + * scale), + (node.x * scale, (2 - node.y) * scale), + bbox=bbox_args, arrowprops=arrow_args, ha='center', + va='bottom', zorder=zorder, xycoords='axes points') + for child in node.children: + draw_nodes(child, scale=scale, zorder=zorder - 1) + + my_tree = _make_tree(0, estimator.tree_) + dt = buchheim(my_tree) + draw_nodes(dt, scale=110) + + def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, From 287c1d22c721392bdde40e9a6a4546747c3e6187 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 29 Jun 2017 17:19:28 -0400 Subject: [PATCH 03/61] object oriented design for export_graphviz so it can be extended --- sklearn/tree/export.py | 471 ++++++++++++++++++++++------------------- 1 file changed, 258 insertions(+), 213 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index f27611845bb7e..9678ef8a18e88 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -231,111 +231,103 @@ def draw_nodes(node, scale=1, zorder=0): draw_nodes(dt, scale=110) -def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None, - feature_names=None, class_names=None, label='all', - filled=False, leaves_parallel=False, impurity=True, - node_ids=False, proportion=False, rotate=False, - rounded=False, special_characters=False, precision=3): - """Export a decision tree in DOT format. - - This function generates a GraphViz representation of the decision tree, - which is then written into `out_file`. Once exported, graphical renderings - can be generated using, for example:: - - $ dot -Tps tree.dot -o tree.ps (PostScript format) - $ dot -Tpng tree.dot -o tree.png (PNG format) - - The sample counts that are shown are weighted with any sample_weights that - might be present. - - Read more in the :ref:`User Guide `. - - Parameters - ---------- - decision_tree : decision tree classifier - The decision tree to be exported to GraphViz. - - out_file : file object or string, optional (default='tree.dot') - Handle or name of the output file. If ``None``, the result is - returned as a string. This will the default from version 0.20. - - max_depth : int, optional (default=None) - The maximum depth of the representation. If None, the tree is fully - generated. - - feature_names : list of strings, optional (default=None) - Names of each of the features. - - class_names : list of strings, bool or None, optional (default=None) - Names of each of the target classes in ascending numerical order. - Only relevant for classification and not supported for multi-output. - If ``True``, shows a symbolic representation of the class name. - - label : {'all', 'root', 'none'}, optional (default='all') - Whether to show informative labels for impurity, etc. - Options include 'all' to show at every node, 'root' to show only at - the top root node, or 'none' to not show at any node. - - filled : bool, optional (default=False) - When set to ``True``, paint nodes to indicate majority class for - classification, extremity of values for regression, or purity of node - for multi-output. - - leaves_parallel : bool, optional (default=False) - When set to ``True``, draw all leaf nodes at the bottom of the tree. - - impurity : bool, optional (default=True) - When set to ``True``, show the impurity at each node. - - node_ids : bool, optional (default=False) - When set to ``True``, show the ID number on each node. - - proportion : bool, optional (default=False) - When set to ``True``, change the display of 'values' and/or 'samples' - to be proportions and percentages respectively. - - rotate : bool, optional (default=False) - When set to ``True``, orient tree left to right rather than top-down. - - rounded : bool, optional (default=False) - When set to ``True``, draw node boxes with rounded corners and use - Helvetica fonts instead of Times-Roman. - - special_characters : bool, optional (default=False) - When set to ``False``, ignore special characters for PostScript - compatibility. - - precision : int, optional (default=3) - Number of digits of precision for floating point in the values of - impurity, threshold and value attributes of each node. +class _DOTTreeExporter(object): + def __init__(self, out_file=SENTINEL, max_depth=None, + feature_names=None, class_names=None, label='all', + filled=False, leaves_parallel=False, impurity=True, + node_ids=False, proportion=False, rotate=False, rounded=False, + special_characters=False, precision=3): + self.out_file = out_file + self.max_depth = max_depth + self.feature_names = feature_names + self.class_names = class_names + self.label = label + self.filled = filled + self.leaves_parallel = leaves_parallel + self.impurity = impurity + self.node_ids = node_ids + self.proportion = proportion + self.rotate = rotate + self.rounded = rounded + self.special_characters = special_characters + self.precision = precision + + # validate + if isinstance(precision, Integral): + if precision < 0: + raise ValueError("'precision' should be greater or equal to 0." + " Got {} instead.".format(precision)) + else: + raise ValueError("'precision' should be an integer. Got {}" + " instead.".format(type(precision))) - Returns - ------- - dot_data : string - String representation of the input tree in GraphViz dot format. - Only returned if ``out_file`` is None. + # The depth of each node for plotting with 'leaf' option + self.ranks = {'leaves': []} + # The colors to render each node with + self.colors = {'bounds': None} - .. versionadded:: 0.18 + def export(self, decision_tree): + # Check length of feature_names before getting into the tree node + # Raise error if length of feature_names does not match + # n_features_ in the decision_tree + if self.feature_names is not None: + if len(self.feature_names) != decision_tree.n_features_: + raise ValueError("Length of feature_names, %d " + "does not match number of features, %d" + % (len(self.feature_names), + decision_tree.n_features_)) + # each part writes to out_file + self.head() + # Now recurse the tree and add node & edge attributes + if isinstance(decision_tree, _tree.Tree): + self.recurse(decision_tree, 0, criterion="impurity") + else: + self.recurse(decision_tree.tree_, 0, + criterion=decision_tree.criterion) - Examples - -------- - >>> from sklearn.datasets import load_iris - >>> from sklearn import tree + self.tail() - >>> clf = tree.DecisionTreeClassifier() - >>> iris = load_iris() + def tail(self): + # If required, draw leaf nodes at same depth as each other + if self.leaves_parallel: + for rank in sorted(self.ranks): + self.out_file.write( + "{rank=same ; " + + "; ".join(r for r in self.ranks[rank]) + "} ;\n") + self.out_file.write("}") - >>> clf = clf.fit(iris.data, iris.target) - >>> tree.export_graphviz(clf, - ... out_file='tree.dot') # doctest: +SKIP + def head(self): + self.out_file.write('digraph Tree {\n') - """ + # Specify node aesthetics + self.out_file.write('node [shape=box') + rounded_filled = [] + if self.filled: + rounded_filled.append('filled') + if self.rounded: + rounded_filled.append('rounded') + if len(rounded_filled) > 0: + self.out_file.write( + ', style="%s", color="black"' + % ", ".join(rounded_filled)) + if self.rounded: + self.out_file.write(', fontname=helvetica') + self.out_file.write('] ;\n') - def get_color(value): + # Specify graph & edge aesthetics + if self.leaves_parallel: + self.out_file.write( + 'graph [ranksep=equally, splines=polyline] ;\n') + if self.rounded: + self.out_file.write('edge [fontname=helvetica] ;\n') + if self.rotate: + self.out_file.write('rankdir=LR ;\n') + + def get_color(self, value): # Find the appropriate color & intensity for a node - if colors['bounds'] is None: + if self.colors['bounds'] is None: # Classification tree - color = list(colors['rgb'][np.argmax(value)]) + color = list(self.colors['rgb'][np.argmax(value)]) sorted_values = sorted(value, reverse=True) if len(sorted_values) == 1: alpha = 0 @@ -345,10 +337,10 @@ def get_color(value): (1 - sorted_values[1]), 0)) else: # Regression tree or multi-output - color = list(colors['rgb'][0]) - alpha = int(np.round(255 * ((value - colors['bounds'][0]) / - (colors['bounds'][1] - - colors['bounds'][0])), 0)) + color = list(self.colors['rgb'][0]) + alpha = int(np.round(255 * ((value - self.colors['bounds'][0]) / + (self.colors['bounds'][1] - + self.colors['bounds'][0])), 0)) # Return html color code in #RRGGBBAA format color.append(alpha) @@ -358,7 +350,7 @@ def get_color(value): return '#' + ''.join(color) - def node_to_str(tree, node_id, criterion): + def node_to_str(self, tree, node_id, criterion): # Generate the node content string if tree.n_outputs == 1: value = tree.value[node_id][0, :] @@ -366,10 +358,10 @@ def node_to_str(tree, node_id, criterion): value = tree.value[node_id] # Should labels be shown? - labels = (label == 'root' and node_id == 0) or label == 'all' + labels = (self.label == 'root' and node_id == 0) or self.label == 'all' # PostScript compatibility for special characters - if special_characters: + if self.special_characters: characters = ['#', '', '', '≤', '
', '>'] node_string = '<' else: @@ -377,7 +369,7 @@ def node_to_str(tree, node_id, criterion): node_string = '"' # Write node ID - if node_ids: + if self.node_ids: if labels: node_string += 'node ' node_string += characters[0] + str(node_id) + characters[4] @@ -385,8 +377,8 @@ def node_to_str(tree, node_id, criterion): # Write decision criteria if tree.children_left[node_id] != _tree.TREE_LEAF: # Always write node decision criteria, except for leaves - if feature_names is not None: - feature = feature_names[tree.feature[node_id]] + if self.feature_names is not None: + feature = self.feature_names[tree.feature[node_id]] else: feature = "X%s%s%s" % (characters[1], tree.feature[node_id], @@ -394,24 +386,24 @@ def node_to_str(tree, node_id, criterion): node_string += '%s %s %s%s' % (feature, characters[3], round(tree.threshold[node_id], - precision), + self.precision), characters[4]) # Write impurity - if impurity: + if self.impurity: if isinstance(criterion, _criterion.FriedmanMSE): criterion = "friedman_mse" elif not isinstance(criterion, six.string_types): criterion = "impurity" if labels: node_string += '%s = ' % criterion - node_string += (str(round(tree.impurity[node_id], precision)) + - characters[4]) + node_string += (str(round(tree.impurity[node_id], self.precision)) + + characters[4]) # Write node sample count if labels: node_string += 'samples = ' - if proportion: + if self.proportion: percent = (100. * tree.n_node_samples[node_id] / float(tree.n_node_samples[0])) node_string += (str(round(percent, 1)) + '%' + @@ -421,23 +413,23 @@ def node_to_str(tree, node_id, criterion): characters[4]) # Write node class distribution / regression value - if proportion and tree.n_classes[0] != 1: + if self.proportion and tree.n_classes[0] != 1: # For classification this will show the proportion of samples value = value / tree.weighted_n_node_samples[node_id] if labels: node_string += 'value = ' if tree.n_classes[0] == 1: # Regression - value_text = np.around(value, precision) - elif proportion: + value_text = np.around(value, self.precision) + elif self.proportion: # Classification - value_text = np.around(value, precision) + value_text = np.around(value, self.precision) elif np.all(np.equal(np.mod(value, 1), 0)): # Classification without floating-point weights value_text = value.astype(int) else: # Classification with floating-point weights - value_text = np.around(value, precision) + value_text = np.around(value, self.precision) # Strip whitespace value_text = str(value_text.astype('S32')).replace("b'", "'") value_text = value_text.replace("' '", ", ").replace("'", "") @@ -447,14 +439,14 @@ def node_to_str(tree, node_id, criterion): node_string += value_text + characters[4] # Write node majority class - if (class_names is not None and + if (self.class_names is not None and tree.n_classes[0] != 1 and tree.n_outputs == 1): # Only done for single-output classification trees if labels: node_string += 'class = ' - if class_names is not True: - class_name = class_names[np.argmax(value)] + if self.class_names is not True: + class_name = self.class_names[np.argmax(value)] else: class_name = "y%s%s%s" % (characters[1], np.argmax(value), @@ -469,7 +461,7 @@ def node_to_str(tree, node_id, criterion): return node_string + characters[5] - def recurse(tree, node_id, criterion, parent=None, depth=0): + def recurse(self, tree, node_id, criterion, parent=None, depth=0): if node_id == _tree.TREE_LEAF: raise ValueError("Invalid node_id %s" % _tree.TREE_LEAF) @@ -477,34 +469,34 @@ def recurse(tree, node_id, criterion, parent=None, depth=0): right_child = tree.children_right[node_id] # Add node with description - if max_depth is None or depth <= max_depth: + if self.max_depth is None or depth <= self.max_depth: # Collect ranks for 'leaf' option in plot_options if left_child == _tree.TREE_LEAF: - ranks['leaves'].append(str(node_id)) - elif str(depth) not in ranks: - ranks[str(depth)] = [str(node_id)] + self.ranks['leaves'].append(str(node_id)) + elif str(depth) not in self.ranks: + self.ranks[str(depth)] = [str(node_id)] else: - ranks[str(depth)].append(str(node_id)) + self.ranks[str(depth)].append(str(node_id)) - out_file.write('%d [label=%s' - % (node_id, - node_to_str(tree, node_id, criterion))) + self.out_file.write( + '%d [label=%s' % (node_id, self.node_to_str(tree, node_id, + criterion))) - if filled: + if self.filled: # Fetch appropriate color for node - if 'rgb' not in colors: + if 'rgb' not in self.colors: # Initialize colors and bounds if required - colors['rgb'] = _color_brew(tree.n_classes[0]) + self.colors['rgb'] = _color_brew(tree.n_classes[0]) if tree.n_outputs != 1: # Find max and min impurities for multi-output - colors['bounds'] = (np.min(-tree.impurity), - np.max(-tree.impurity)) + self.colors['bounds'] = (np.min(-tree.impurity), + np.max(-tree.impurity)) elif (tree.n_classes[0] == 1 and len(np.unique(tree.value)) != 1): # Find max and min values in leaf nodes for regression - colors['bounds'] = (np.min(tree.value), - np.max(tree.value)) + self.colors['bounds'] = (np.min(tree.value), + np.max(tree.value)) if tree.n_outputs == 1: node_val = (tree.value[node_id][0, :] / tree.weighted_n_node_samples[node_id]) @@ -514,40 +506,144 @@ def recurse(tree, node_id, criterion, parent=None, depth=0): else: # If multi-output color node by impurity node_val = -tree.impurity[node_id] - out_file.write(', fillcolor="%s"' % get_color(node_val)) - out_file.write('] ;\n') + self.out_file.write(', fillcolor="%s"' + % self.get_color(node_val)) + self.out_file.write('] ;\n') if parent is not None: # Add edge to parent - out_file.write('%d -> %d' % (parent, node_id)) + self.out_file.write('%d -> %d' % (parent, node_id)) if parent == 0: # Draw True/False labels if parent is root node - angles = np.array([45, -45]) * ((rotate - .5) * -2) - out_file.write(' [labeldistance=2.5, labelangle=') + angles = np.array([45, -45]) * ((self.rotate - .5) * -2) + self.out_file.write(' [labeldistance=2.5, labelangle=') if node_id == 1: - out_file.write('%d, headlabel="True"]' % angles[0]) + self.out_file.write('%d, headlabel="True"]' % + angles[0]) else: - out_file.write('%d, headlabel="False"]' % angles[1]) - out_file.write(' ;\n') + self.out_file.write('%d, headlabel="False"]' % + angles[1]) + self.out_file.write(' ;\n') if left_child != _tree.TREE_LEAF: - recurse(tree, left_child, criterion=criterion, parent=node_id, - depth=depth + 1) - recurse(tree, right_child, criterion=criterion, parent=node_id, - depth=depth + 1) + self.recurse(tree, left_child, criterion=criterion, + parent=node_id, depth=depth + 1) + self.recurse(tree, right_child, criterion=criterion, + parent=node_id, depth=depth + 1) else: - ranks['leaves'].append(str(node_id)) + self.ranks['leaves'].append(str(node_id)) - out_file.write('%d [label="(...)"' % node_id) - if filled: + self.out_file.write('%d [label="(...)"' % node_id) + if self.filled: # color cropped nodes grey - out_file.write(', fillcolor="#C0C0C0"') - out_file.write('] ;\n' % node_id) + self.out_file.write(', fillcolor="#C0C0C0"') + self.out_file.write('] ;\n' % node_id) if parent is not None: # Add edge to parent - out_file.write('%d -> %d ;\n' % (parent, node_id)) + self.out_file.write('%d -> %d ;\n' % (parent, node_id)) + + +def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None, + feature_names=None, class_names=None, label='all', + filled=False, leaves_parallel=False, impurity=True, + node_ids=False, proportion=False, rotate=False, + rounded=False, special_characters=False, precision=3): + """Export a decision tree in DOT format. + + This function generates a GraphViz representation of the decision tree, + which is then written into `out_file`. Once exported, graphical renderings + can be generated using, for example:: + + $ dot -Tps tree.dot -o tree.ps (PostScript format) + $ dot -Tpng tree.dot -o tree.png (PNG format) + + The sample counts that are shown are weighted with any sample_weights that + might be present. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + decision_tree : decision tree classifier + The decision tree to be exported to GraphViz. + + out_file : file object or string, optional (default='tree.dot') + Handle or name of the output file. If ``None``, the result is + returned as a string. This will the default from version 0.20. + + max_depth : int, optional (default=None) + The maximum depth of the representation. If None, the tree is fully + generated. + + feature_names : list of strings, optional (default=None) + Names of each of the features. + + class_names : list of strings, bool or None, optional (default=None) + Names of each of the target classes in ascending numerical order. + Only relevant for classification and not supported for multi-output. + If ``True``, shows a symbolic representation of the class name. + + label : {'all', 'root', 'none'}, optional (default='all') + Whether to show informative labels for impurity, etc. + Options include 'all' to show at every node, 'root' to show only at + the top root node, or 'none' to not show at any node. + + filled : bool, optional (default=False) + When set to ``True``, paint nodes to indicate majority class for + classification, extremity of values for regression, or purity of node + for multi-output. + + leaves_parallel : bool, optional (default=False) + When set to ``True``, draw all leaf nodes at the bottom of the tree. + + impurity : bool, optional (default=True) + When set to ``True``, show the impurity at each node. + + node_ids : bool, optional (default=False) + When set to ``True``, show the ID number on each node. + + proportion : bool, optional (default=False) + When set to ``True``, change the display of 'values' and/or 'samples' + to be proportions and percentages respectively. + + rotate : bool, optional (default=False) + When set to ``True``, orient tree left to right rather than top-down. + + rounded : bool, optional (default=False) + When set to ``True``, draw node boxes with rounded corners and use + Helvetica fonts instead of Times-Roman. + + special_characters : bool, optional (default=False) + When set to ``False``, ignore special characters for PostScript + compatibility. + + precision : int, optional (default=3) + Number of digits of precision for floating point in the values of + impurity, threshold and value attributes of each node. + + Returns + ------- + dot_data : string + String representation of the input tree in GraphViz dot format. + Only returned if ``out_file`` is None. + + .. versionadded:: 0.18 + + Examples + -------- + >>> from sklearn.datasets import load_iris + >>> from sklearn import tree + + >>> clf = tree.DecisionTreeClassifier() + >>> iris = load_iris() + + >>> clf = clf.fit(iris.data, iris.target) + >>> tree.export_graphviz(clf, + ... out_file='tree.dot') # doctest: +SKIP + + """ check_is_fitted(decision_tree, 'tree_') own_file = False @@ -570,68 +666,17 @@ def recurse(tree, node_id, criterion, parent=None, depth=0): return_string = True out_file = six.StringIO() - if isinstance(precision, Integral): - if precision < 0: - raise ValueError("'precision' should be greater or equal to 0." - " Got {} instead.".format(precision)) - else: - raise ValueError("'precision' should be an integer. Got {}" - " instead.".format(type(precision))) - - # Check length of feature_names before getting into the tree node - # Raise error if length of feature_names does not match - # n_features_ in the decision_tree - if feature_names is not None: - if len(feature_names) != decision_tree.n_features_: - raise ValueError("Length of feature_names, %d " - "does not match number of features, %d" - % (len(feature_names), - decision_tree.n_features_)) - - # The depth of each node for plotting with 'leaf' option - ranks = {'leaves': []} - # The colors to render each node with - colors = {'bounds': None} - - out_file.write('digraph Tree {\n') - - # Specify node aesthetics - out_file.write('node [shape=box') - rounded_filled = [] - if filled: - rounded_filled.append('filled') - if rounded: - rounded_filled.append('rounded') - if len(rounded_filled) > 0: - out_file.write(', style="%s", color="black"' - % ", ".join(rounded_filled)) - if rounded: - out_file.write(', fontname=helvetica') - out_file.write('] ;\n') - - # Specify graph & edge aesthetics - if leaves_parallel: - out_file.write('graph [ranksep=equally, splines=polyline] ;\n') - if rounded: - out_file.write('edge [fontname=helvetica] ;\n') - if rotate: - out_file.write('rankdir=LR ;\n') - - # Now recurse the tree and add node & edge attributes - if isinstance(decision_tree, _tree.Tree): - recurse(decision_tree, 0, criterion="impurity") - else: - recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion) - - # If required, draw leaf nodes at same depth as each other - if leaves_parallel: - for rank in sorted(ranks): - out_file.write("{rank=same ; " + - "; ".join(r for r in ranks[rank]) + "} ;\n") - out_file.write("}") + exporter = _DOTTreeExporter( + out_file=out_file, max_depth=max_depth, + feature_names=feature_names, class_names=class_names, label=label, + filled=filled, leaves_parallel=leaves_parallel, impurity=impurity, + node_ids=node_ids, proportion=proportion, rotate=rotate, + rounded=rounded, special_characters=special_characters, + precision=precision) + exporter.export(decision_tree) if return_string: - return out_file.getvalue() + return exporter.out_file.getvalue() finally: if own_file: From 0d5e3e267d7153c06dabcfc1601f682db3b18f39 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 29 Jun 2017 17:43:29 -0400 Subject: [PATCH 04/61] add class for mlp export --- sklearn/tree/export.py | 308 +++++++++++++++++------------------------ 1 file changed, 125 insertions(+), 183 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 9678ef8a18e88..9f9cf65cabe9e 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -74,161 +74,23 @@ def __repr__(): SENTINEL = Sentinel() -def node_to_str(tree, node_id, criterion): - # stupid copy & paste with a few adjustments - label = 'all' - feature_names = None - class_names = None - label = 'all' - impurity = True - node_ids = False - proportion = False - special_characters = False - precision = 3 - # Generate the node content string - if tree.n_outputs == 1: - value = tree.value[node_id][0, :] - else: - value = tree.value[node_id] - # Should labels be shown? - labels = (label == 'root' and node_id == 0) or label == 'all' - - # PostScript compatibility for special characters - if special_characters: - characters = ['#', '', '', '≤', '
', '>'] - node_string = '<' - else: - characters = ['#', '[', ']', '<=', '\n', ''] - node_string = '' - - # Write node ID - if node_ids: - if labels: - node_string += 'node ' - node_string += characters[0] + str(node_id) + characters[4] - - # Write decision criteria - if tree.children_left[node_id] != _tree.TREE_LEAF: - # Always write node decision criteria, except for leaves - if feature_names is not None: - feature = feature_names[tree.feature[node_id]] - else: - feature = "X%s%s%s" % (characters[1], - tree.feature[node_id], - characters[2]) - node_string += '%s %s %s%s' % (feature, - characters[3], - round(tree.threshold[node_id], - precision), - characters[4]) - - # Write impurity - if impurity: - if isinstance(criterion, _criterion.FriedmanMSE): - criterion = "friedman_mse" - elif not isinstance(criterion, six.string_types): - criterion = "impurity" - if labels: - node_string += '%s = ' % criterion - node_string += (str(round(tree.impurity[node_id], precision)) + - characters[4]) - - # Write node sample count - if labels: - node_string += 'samples = ' - if proportion: - percent = (100. * tree.n_node_samples[node_id] / - float(tree.n_node_samples[0])) - node_string += (str(round(percent, 1)) + '%' + - characters[4]) - else: - node_string += (str(tree.n_node_samples[node_id]) + - characters[4]) - - # Write node class distribution / regression value - if proportion and tree.n_classes[0] != 1: - # For classification this will show the proportion of samples - value = value / tree.weighted_n_node_samples[node_id] - if labels: - node_string += 'value = ' - if tree.n_classes[0] == 1: - # Regression - value_text = np.around(value, precision) - elif proportion: - # Classification - value_text = np.around(value, precision) - elif np.all(np.equal(np.mod(value, 1), 0)): - # Classification without floating-point weights - value_text = value.astype(int) - else: - # Classification with floating-point weights - value_text = np.around(value, precision) - # Strip whitespace - value_text = str(value_text.astype('S32')).replace("b'", "'") - value_text = value_text.replace("' '", ", ").replace("'", "") - if tree.n_classes[0] == 1 and tree.n_outputs == 1: - value_text = value_text.replace("[", "").replace("]", "") - value_text = value_text.replace("\n ", characters[4]) - node_string += value_text + characters[4] - - # Write node majority class - if (class_names is not None and - tree.n_classes[0] != 1 and - tree.n_outputs == 1): - # Only done for single-output classification trees - if labels: - node_string += 'class = ' - if class_names is not True: - class_name = class_names[np.argmax(value)] - else: - class_name = "y%s%s%s" % (characters[1], - np.argmax(value), - characters[2]) - node_string += class_name - - # Clean up any trailing newlines - if node_string[-2:] == '\n': - node_string = node_string[:-2] - if node_string[-5:] == '
': - node_string = node_string[:-5] - - return node_string + characters[5] - - -def _make_tree(node_id, et): - # traverses _tree.Tree recursively, builds intermediate "Tree" object - name = node_to_str(et, 0, criterion='entropy') - if (et.children_left[node_id] != et.children_right[node_id]): - children = [_make_tree(et.children_left[node_id], et), _make_tree( - et.children_right[node_id], et)] - else: - return Tree(name) - return Tree(name, *children) - - -def plot_tree(estimator): +def plot_tree(decision_tree, max_depth=None, feature_names=None, + class_names=None, label='all', filled=False, + leaves_parallel=False, impurity=True, node_ids=False, + proportion=False, rotate=False, rounded=False, + special_characters=False, precision=3, ax=None, scale=110): import matplotlib.pyplot as plt - bbox_args = dict(boxstyle="round", fc="0.8") - arrow_args = dict(arrowstyle="-") - - def draw_nodes(node, scale=1, zorder=0): - # 2 - is a hack to for not creating empty space. FIXME - if node.parent is None: - plt.annotate(node.tree, (node.x * scale, (2 - node.y) * scale), - bbox=bbox_args, ha='center', va='bottom', - zorder=zorder, xycoords='axes points') - else: - plt.annotate(node.tree, (node.parent.x * scale, (2 - node.parent.y) - * scale), - (node.x * scale, (2 - node.y) * scale), - bbox=bbox_args, arrowprops=arrow_args, ha='center', - va='bottom', zorder=zorder, xycoords='axes points') - for child in node.children: - draw_nodes(child, scale=scale, zorder=zorder - 1) + if ax is None: + ax = plt.gca() - my_tree = _make_tree(0, estimator.tree_) - dt = buchheim(my_tree) - draw_nodes(dt, scale=110) + exporter = _MPLTreeExporter( + max_depth=max_depth, feature_names=feature_names, + class_names=class_names, label=label, filled=filled, + leaves_parallel=leaves_parallel, impurity=impurity, node_ids=node_ids, + proportion=proportion, rotate=rotate, rounded=rounded, + special_characters=special_characters, precision=precision, ax=ax, + scale=scale) + exporter.export(decision_tree) class _DOTTreeExporter(object): @@ -252,6 +114,13 @@ def __init__(self, out_file=SENTINEL, max_depth=None, self.special_characters = special_characters self.precision = precision + # PostScript compatibility for special characters + if special_characters: + self.characters = ['#', '', '', '≤', '
', + '>', '<'] + else: + self.characters = ['#', '[', ']', '<=', '\\n', '"', '"'] + # validate if isinstance(precision, Integral): if precision < 0: @@ -350,6 +219,31 @@ def get_color(self, value): return '#' + ''.join(color) + def get_fill_color(self, tree, node_id): + # Fetch appropriate color for node + if 'rgb' not in self.colors: + # Initialize colors and bounds if required + self.colors['rgb'] = _color_brew(tree.n_classes[0]) + if tree.n_outputs != 1: + # Find max and min impurities for multi-output + self.colors['bounds'] = (np.min(-tree.impurity), + np.max(-tree.impurity)) + elif (tree.n_classes[0] == 1 and + len(np.unique(tree.value)) != 1): + # Find max and min values in leaf nodes for regression + self.colors['bounds'] = (np.min(tree.value), + np.max(tree.value)) + if tree.n_outputs == 1: + node_val = (tree.value[node_id][0, :] / + tree.weighted_n_node_samples[node_id]) + if tree.n_classes[0] == 1: + # Regression + node_val = tree.value[node_id][0, :] + else: + # If multi-output color node by impurity + node_val = -tree.impurity[node_id] + return self.get_color(node_val) + def node_to_str(self, tree, node_id, criterion): # Generate the node content string if tree.n_outputs == 1: @@ -360,13 +254,8 @@ def node_to_str(self, tree, node_id, criterion): # Should labels be shown? labels = (self.label == 'root' and node_id == 0) or self.label == 'all' - # PostScript compatibility for special characters - if self.special_characters: - characters = ['#', '', '', '≤', '
', '>'] - node_string = '<' - else: - characters = ['#', '[', ']', '<=', '\\n', '"'] - node_string = '"' + characters = self.characters + node_string = characters[-1] # Write node ID if self.node_ids: @@ -484,30 +373,8 @@ def recurse(self, tree, node_id, criterion, parent=None, depth=0): criterion))) if self.filled: - # Fetch appropriate color for node - if 'rgb' not in self.colors: - # Initialize colors and bounds if required - self.colors['rgb'] = _color_brew(tree.n_classes[0]) - if tree.n_outputs != 1: - # Find max and min impurities for multi-output - self.colors['bounds'] = (np.min(-tree.impurity), - np.max(-tree.impurity)) - elif (tree.n_classes[0] == 1 and - len(np.unique(tree.value)) != 1): - # Find max and min values in leaf nodes for regression - self.colors['bounds'] = (np.min(tree.value), - np.max(tree.value)) - if tree.n_outputs == 1: - node_val = (tree.value[node_id][0, :] / - tree.weighted_n_node_samples[node_id]) - if tree.n_classes[0] == 1: - # Regression - node_val = tree.value[node_id][0, :] - else: - # If multi-output color node by impurity - node_val = -tree.impurity[node_id] self.out_file.write(', fillcolor="%s"' - % self.get_color(node_val)) + % self.get_fill_color(tree, node_id)) self.out_file.write('] ;\n') if parent is not None: @@ -545,6 +412,81 @@ def recurse(self, tree, node_id, criterion, parent=None, depth=0): self.out_file.write('%d -> %d ;\n' % (parent, node_id)) +class _MPLTreeExporter(_DOTTreeExporter): + def __init__(self, ax, max_depth=None, feature_names=None, + class_names=None, label='all', filled=False, + leaves_parallel=False, impurity=True, node_ids=False, + proportion=False, rotate=False, rounded=False, + special_characters=False, precision=3, scale=110): + self.max_depth = max_depth + self.feature_names = feature_names + self.class_names = class_names + self.label = label + self.filled = filled + self.leaves_parallel = leaves_parallel + self.impurity = impurity + self.node_ids = node_ids + self.proportion = proportion + self.rotate = rotate + self.rounded = rounded + self.special_characters = special_characters + self.precision = precision + self.scale = scale + self.ax = ax + + # validate + if isinstance(precision, Integral): + if precision < 0: + raise ValueError("'precision' should be greater or equal to 0." + " Got {} instead.".format(precision)) + else: + raise ValueError("'precision' should be an integer. Got {}" + " instead.".format(type(precision))) + + # The depth of each node for plotting with 'leaf' option + self.ranks = {'leaves': []} + # The colors to render each node with + self.colors = {'bounds': None} + + self.characters = ['#', '[', ']', '<=', '\n', '', ''] + + self.bbox_args = dict(boxstyle="round", fc="0.8") + self.arrow_args = dict(arrowstyle="-") + + def _make_tree(self, node_id, et): + # traverses _tree.Tree recursively, builds intermediate "Tree" object + name = self.node_to_str(et, 0, criterion='entropy') + if (et.children_left[node_id] != et.children_right[node_id]): + children = [self._make_tree(et.children_left[node_id], et), + self._make_tree(et.children_right[node_id], et)] + else: + return Tree(name) + return Tree(name, *children) + + def export(self, decision_tree): + my_tree = self._make_tree(0, decision_tree.tree_) + dt = buchheim(my_tree) + self.recurse(dt) + + def recurse(self, node, zorder=0): + # 2 - is a hack to for not creating empty space. FIXME + if node.parent is None: + self.ax.annotate( + node.tree, (node.x * self.scale, (2 - node.y) * self.scale), + bbox=self.bbox_args, ha='center', va='bottom', zorder=zorder, + xycoords='axes points') + else: + self.ax.annotate( + node.tree, (node.parent.x * self.scale, (2 - node.parent.y) * + self.scale), (node.x * self.scale, (2 - node.y) * + self.scale), bbox=self.bbox_args, + arrowprops=self.arrow_args, ha='center', va='bottom', + zorder=zorder, + xycoords='axes points') + for child in node.children: + self.recurse(child, zorder=zorder - 1) + + def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, From 8f52d87ed320e068916bb725a79914be5270ec25 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 29 Jun 2017 18:14:43 -0400 Subject: [PATCH 05/61] add colors --- sklearn/tree/_reingold_tilford.py | 3 ++- sklearn/tree/export.py | 37 +++++++++++++++++-------------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/sklearn/tree/_reingold_tilford.py b/sklearn/tree/_reingold_tilford.py index 4790a90dc6546..44472bef5b85b 100644 --- a/sklearn/tree/_reingold_tilford.py +++ b/sklearn/tree/_reingold_tilford.py @@ -175,9 +175,10 @@ def second_walk(v, m=0, depth=0, min=None): # my stuff class Tree(object): - def __init__(self, node="", *children): + def __init__(self, node="", node_id=-1, *children): self.node = node self.width = len(node) + self.node_id = node_id if children: self.children = children else: diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 9f9cf65cabe9e..c737d94dcc434 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -454,37 +454,40 @@ def __init__(self, ax, max_depth=None, feature_names=None, self.arrow_args = dict(arrowstyle="-") def _make_tree(self, node_id, et): - # traverses _tree.Tree recursively, builds intermediate "Tree" object + # traverses _tree.Tree recursively, builds intermediate + # "_reingold_tilford.Tree" object name = self.node_to_str(et, 0, criterion='entropy') if (et.children_left[node_id] != et.children_right[node_id]): children = [self._make_tree(et.children_left[node_id], et), self._make_tree(et.children_right[node_id], et)] else: - return Tree(name) - return Tree(name, *children) + return Tree(name, node_id) + return Tree(name, node_id, *children) def export(self, decision_tree): + self.ax.set_axis_off() my_tree = self._make_tree(0, decision_tree.tree_) dt = buchheim(my_tree) - self.recurse(dt) + self.recurse(dt, decision_tree.tree_) - def recurse(self, node, zorder=0): + def recurse(self, node, tree, zorder=0): # 2 - is a hack to for not creating empty space. FIXME + kwargs = dict(bbox=self.bbox_args, ha='center', va='bottom', + zorder=zorder, xycoords='axes points') + xy = (node.x * self.scale, (2 - node.y) * self.scale) + if self.filled: + kwargs['bbox']['fc'] = self.get_fill_color(tree, + node.tree.node_id) if node.parent is None: - self.ax.annotate( - node.tree, (node.x * self.scale, (2 - node.y) * self.scale), - bbox=self.bbox_args, ha='center', va='bottom', zorder=zorder, - xycoords='axes points') + # root + self.ax.annotate(node.tree, xy, **kwargs) else: - self.ax.annotate( - node.tree, (node.parent.x * self.scale, (2 - node.parent.y) * - self.scale), (node.x * self.scale, (2 - node.y) * - self.scale), bbox=self.bbox_args, - arrowprops=self.arrow_args, ha='center', va='bottom', - zorder=zorder, - xycoords='axes points') + xy_parent = (node.parent.x * self.scale, (2 - node.parent.y) * + self.scale) + kwargs["arrowprops"] = self.arrow_args + self.ax.annotate(node.tree, xy_parent, xy, **kwargs) for child in node.children: - self.recurse(child, zorder=zorder - 1) + self.recurse(child, tree, zorder=zorder - 1) def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None, From 4a5fe67013e01d3abe897492cc4163e816391b67 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 29 Jun 2017 18:29:06 -0400 Subject: [PATCH 06/61] separately scale x and y, add arrowheads, fix strings --- sklearn/tree/export.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index c737d94dcc434..00fde70911f79 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -78,7 +78,8 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3, ax=None, scale=110): + special_characters=False, precision=3, ax=None, scalex=150, + scaley=1): import matplotlib.pyplot as plt if ax is None: ax = plt.gca() @@ -89,7 +90,7 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, leaves_parallel=leaves_parallel, impurity=impurity, node_ids=node_ids, proportion=proportion, rotate=rotate, rounded=rounded, special_characters=special_characters, precision=precision, ax=ax, - scale=scale) + scalex=scalex, scaley=scaley) exporter.export(decision_tree) @@ -417,7 +418,7 @@ def __init__(self, ax, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3, scale=110): + special_characters=False, precision=3, scalex=150, scaley=1): self.max_depth = max_depth self.feature_names = feature_names self.class_names = class_names @@ -431,7 +432,8 @@ def __init__(self, ax, max_depth=None, feature_names=None, self.rounded = rounded self.special_characters = special_characters self.precision = precision - self.scale = scale + self.scalex = scalex + self.scaley = scaley self.ax = ax # validate @@ -450,13 +452,15 @@ def __init__(self, ax, max_depth=None, feature_names=None, self.characters = ['#', '[', ']', '<=', '\n', '', ''] - self.bbox_args = dict(boxstyle="round", fc="0.8") - self.arrow_args = dict(arrowstyle="-") + self.bbox_args = dict(fc='w') + if self.rounded: + self.bbox_args['boxstyle'] = "round" + self.arrow_args = dict(arrowstyle="<-") def _make_tree(self, node_id, et): # traverses _tree.Tree recursively, builds intermediate # "_reingold_tilford.Tree" object - name = self.node_to_str(et, 0, criterion='entropy') + name = self.node_to_str(et, node_id, criterion='entropy') if (et.children_left[node_id] != et.children_right[node_id]): children = [self._make_tree(et.children_left[node_id], et), self._make_tree(et.children_right[node_id], et)] @@ -470,24 +474,24 @@ def export(self, decision_tree): dt = buchheim(my_tree) self.recurse(dt, decision_tree.tree_) - def recurse(self, node, tree, zorder=0): + def recurse(self, node, tree, zorder=100): # 2 - is a hack to for not creating empty space. FIXME kwargs = dict(bbox=self.bbox_args, ha='center', va='bottom', zorder=zorder, xycoords='axes points') - xy = (node.x * self.scale, (2 - node.y) * self.scale) + xy = (node.x * self.scalex, (2 - node.y) * self.scaley) if self.filled: kwargs['bbox']['fc'] = self.get_fill_color(tree, node.tree.node_id) if node.parent is None: # root - self.ax.annotate(node.tree, xy, **kwargs) + self.ax.annotate(node.tree.node, xy, **kwargs) else: - xy_parent = (node.parent.x * self.scale, (2 - node.parent.y) * - self.scale) + xy_parent = (node.parent.x * self.scalex, (2 - node.parent.y) * + self.scaley) kwargs["arrowprops"] = self.arrow_args - self.ax.annotate(node.tree, xy_parent, xy, **kwargs) + self.ax.annotate(node.tree.node, xy_parent, xy, **kwargs) for child in node.children: - self.recurse(child, tree, zorder=zorder - 1) + self.recurse(child, tree, zorder=zorder - 10) def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None, From ddb6c1662af9fbba8d0907cec63290e3c53aa376 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 29 Jun 2017 19:07:04 -0400 Subject: [PATCH 07/61] implement max_depth --- sklearn/tree/export.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 00fde70911f79..2e84df3da7246 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -474,24 +474,33 @@ def export(self, decision_tree): dt = buchheim(my_tree) self.recurse(dt, decision_tree.tree_) - def recurse(self, node, tree, zorder=100): + def recurse(self, node, tree, depth=0): # 2 - is a hack to for not creating empty space. FIXME kwargs = dict(bbox=self.bbox_args, ha='center', va='bottom', - zorder=zorder, xycoords='axes points') + zorder=100 - 10 * depth, xycoords='axes points') xy = (node.x * self.scalex, (2 - node.y) * self.scaley) - if self.filled: - kwargs['bbox']['fc'] = self.get_fill_color(tree, - node.tree.node_id) - if node.parent is None: - # root - self.ax.annotate(node.tree.node, xy, **kwargs) + + if self.max_depth is None or depth <= self.max_depth: + if self.filled: + kwargs['bbox']['fc'] = self.get_fill_color(tree, + node.tree.node_id) + if node.parent is None: + # root + self.ax.annotate(node.tree.node, xy, **kwargs) + else: + xy_parent = (node.parent.x * self.scalex, (2 - node.parent.y) * + self.scaley) + kwargs["arrowprops"] = self.arrow_args + self.ax.annotate(node.tree.node, xy_parent, xy, **kwargs) + for child in node.children: + self.recurse(child, tree, depth=depth + 1) + else: xy_parent = (node.parent.x * self.scalex, (2 - node.parent.y) * self.scaley) kwargs["arrowprops"] = self.arrow_args - self.ax.annotate(node.tree.node, xy_parent, xy, **kwargs) - for child in node.children: - self.recurse(child, tree, zorder=zorder - 10) + kwargs['bbox']['fc'] = 'grey' + self.ax.annotate("\n (...) \n", xy_parent, xy, **kwargs) def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None, From fed2d1d9c48be68b14632d5fff93e95e21b3e723 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 29 Jun 2017 19:36:03 -0400 Subject: [PATCH 08/61] don't use alpha for coloring because it makes boxes transparent --- sklearn/tree/export.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 2e84df3da7246..545dbafa2ac19 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -202,18 +202,16 @@ def get_color(self, value): if len(sorted_values) == 1: alpha = 0 else: - alpha = int(np.round(255 * (sorted_values[0] - - sorted_values[1]) / - (1 - sorted_values[1]), 0)) + alpha = ((sorted_values[0] - sorted_values[1]) + / (1 - sorted_values[1])) else: # Regression tree or multi-output color = list(self.colors['rgb'][0]) - alpha = int(np.round(255 * ((value - self.colors['bounds'][0]) / - (self.colors['bounds'][1] - - self.colors['bounds'][0])), 0)) - - # Return html color code in #RRGGBBAA format - color.append(alpha) + alpha = ((value - self.colors['bounds'][0]) / + (self.colors['bounds'][1] - self.colors['bounds'][0])) + # compute the color as alpha against white + color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color] + # Return html color code in #RRGGBB format hex_codes = [str(i) for i in range(10)] hex_codes.extend(['a', 'b', 'c', 'd', 'e', 'f']) color = [hex_codes[c // 16] + hex_codes[c % 16] for c in color] From 5145ed29ef33186a8490d04e987398ae6f132924 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 29 Jun 2017 19:39:00 -0400 Subject: [PATCH 09/61] remove unused variables --- sklearn/tree/_reingold_tilford.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearn/tree/_reingold_tilford.py b/sklearn/tree/_reingold_tilford.py index 44472bef5b85b..60eb6f556fda7 100644 --- a/sklearn/tree/_reingold_tilford.py +++ b/sklearn/tree/_reingold_tilford.py @@ -78,8 +78,6 @@ def firstwalk(v, distance=1.): midpoint = (v.children[0].x + v.children[-1].x) / 2 - ell = v.children[0] - arr = v.children[-1] w = v.lbrother() if w: v.x = w.x + distance From 8663ad78590063a4093b7cf8e84b3449f7cf36a6 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 29 Jun 2017 19:41:38 -0400 Subject: [PATCH 10/61] vertical center of boxes --- sklearn/tree/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 545dbafa2ac19..b6c098d10d0fd 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -474,7 +474,7 @@ def export(self, decision_tree): def recurse(self, node, tree, depth=0): # 2 - is a hack to for not creating empty space. FIXME - kwargs = dict(bbox=self.bbox_args, ha='center', va='bottom', + kwargs = dict(bbox=self.bbox_args, ha='center', va='center', zorder=100 - 10 * depth, xycoords='axes points') xy = (node.x * self.scalex, (2 - node.y) * self.scaley) From d750deb8c880122feae28ae76e2489a4007702de Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 29 Jun 2017 19:48:34 -0400 Subject: [PATCH 11/61] fix/simplify newline trimming --- sklearn/tree/export.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index b6c098d10d0fd..5e8668349e4de 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -342,10 +342,8 @@ def node_to_str(self, tree, node_id, criterion): node_string += class_name # Clean up any trailing newlines - if node_string[-2:] == '\\n': - node_string = node_string[:-2] - if node_string[-5:] == '
': - node_string = node_string[:-5] + if node_string.endswith(characters[4]): + node_string = node_string[:-len(characters[4])] return node_string + characters[5] From d3c17eaa217d06a870a76c23098ed5e0efa0670f Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 30 Jun 2017 11:36:25 -0400 Subject: [PATCH 12/61] somewhere in the middle of stuff trying to get rid of scalex, scaley --- sklearn/tree/_reingold_tilford.py | 22 ---------------------- sklearn/tree/export.py | 15 +++++++++++---- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/sklearn/tree/_reingold_tilford.py b/sklearn/tree/_reingold_tilford.py index 60eb6f556fda7..9e710dc75bca7 100644 --- a/sklearn/tree/_reingold_tilford.py +++ b/sklearn/tree/_reingold_tilford.py @@ -170,8 +170,6 @@ def second_walk(v, m=0, depth=0, min=None): return min -# my stuff - class Tree(object): def __init__(self, node="", node_id=-1, *children): self.node = node @@ -181,23 +179,3 @@ def __init__(self, node="", node_id=-1, *children): self.children = children else: self.children = [] - - def __str__(self): - return "%s" % (self.node) - - def __repr__(self): - return "%s" % (self.node) - - def __getitem__(self, key): - if isinstance(key, int) or isinstance(key, slice): - return self.children[key] - if isinstance(key, str): - for child in self.children: - if child.node == key: - return child - - def __iter__(self): - return self.children.__iter__() - - def __len__(self): - return len(self.children) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 5e8668349e4de..34d5161263027 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -78,8 +78,7 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3, ax=None, scalex=150, - scaley=1): + special_characters=False, precision=3, ax=None, scalex=150): import matplotlib.pyplot as plt if ax is None: ax = plt.gca() @@ -90,7 +89,7 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, leaves_parallel=leaves_parallel, impurity=impurity, node_ids=node_ids, proportion=proportion, rotate=rotate, rounded=rounded, special_characters=special_characters, precision=precision, ax=ax, - scalex=scalex, scaley=scaley) + scalex=scalex) exporter.export(decision_tree) @@ -429,8 +428,8 @@ def __init__(self, ax, max_depth=None, feature_names=None, self.special_characters = special_characters self.precision = precision self.scalex = scalex - self.scaley = scaley self.ax = ax + self.scaley = 80 if class_names is None else 100 # validate if isinstance(precision, Integral): @@ -464,9 +463,17 @@ def _make_tree(self, node_id, et): return Tree(name, node_id) return Tree(name, node_id, *children) + def _find_longest(self, my_tree, max_length): + child_length = [_find_longest(c, max_length) for c + in my_cildren] + return max(child_length + [max_length]) + + def export(self, decision_tree): self.ax.set_axis_off() my_tree = self._make_tree(0, decision_tree.tree_) + # find longest string: + self._find_longest(my_tree, len(my_tree.node))) dt = buchheim(my_tree) self.recurse(dt, decision_tree.tree_) From 823ce1f0b046d5d1ab6ecbc43d5ba1bee8e10e5a Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 6 Jul 2017 17:37:07 -0400 Subject: [PATCH 13/61] remove "find_longest_child" for now, fix tests --- sklearn/tree/export.py | 13 +++++++------ sklearn/tree/tests/test_export.py | 26 +++++++++++++------------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 34d5161263027..76da293b30742 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -208,6 +208,8 @@ def get_color(self, value): color = list(self.colors['rgb'][0]) alpha = ((value - self.colors['bounds'][0]) / (self.colors['bounds'][1] - self.colors['bounds'][0])) + # unpack numpy scalars + alpha = float(alpha) # compute the color as alpha against white color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color] # Return html color code in #RRGGBB format @@ -463,17 +465,16 @@ def _make_tree(self, node_id, et): return Tree(name, node_id) return Tree(name, node_id, *children) - def _find_longest(self, my_tree, max_length): - child_length = [_find_longest(c, max_length) for c - in my_cildren] - return max(child_length + [max_length]) - + # def _find_longest(self, my_tree, max_length): + # child_length = [_find_longest(c, max_length) for c + # in my_cildren] + # return max(child_length + [max_length]) def export(self, decision_tree): self.ax.set_axis_off() my_tree = self._make_tree(0, decision_tree.tree_) # find longest string: - self._find_longest(my_tree, len(my_tree.node))) + # self._find_longest(my_tree, len(my_tree.node))) dt = buchheim(my_tree) self.recurse(dt, decision_tree.tree_) diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index 0bf70073d34c7..ed119b27f632a 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -92,13 +92,13 @@ def test_graphviz_toy(): 'fontname=helvetica] ;\n' \ 'edge [fontname=helvetica] ;\n' \ '0 [label=0 ≤ 0.0
samples = 100.0%
' \ - 'value = [0.5, 0.5]>, fillcolor="#e5813900"] ;\n' \ + 'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n' \ '1 [label=value = [1.0, 0.0]>, ' \ - 'fillcolor="#e58139ff"] ;\n' \ + 'fillcolor="#e58139"] ;\n' \ '0 -> 1 [labeldistance=2.5, labelangle=45, ' \ 'headlabel="True"] ;\n' \ '2 [label=value = [0.0, 1.0]>, ' \ - 'fillcolor="#399de5ff"] ;\n' \ + 'fillcolor="#399de5"] ;\n' \ '0 -> 2 [labeldistance=2.5, labelangle=-45, ' \ 'headlabel="False"] ;\n' \ '}' @@ -126,7 +126,7 @@ def test_graphviz_toy(): contents2 = 'digraph Tree {\n' \ 'node [shape=box, style="filled", color="black"] ;\n' \ '0 [label="node #0\\nX[0] <= 0.0\\ngini = 0.5\\n' \ - 'samples = 6\\nvalue = [3, 3]", fillcolor="#e5813900"] ;\n' \ + 'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n' \ '1 [label="(...)", fillcolor="#C0C0C0"] ;\n' \ '0 -> 1 ;\n' \ '2 [label="(...)", fillcolor="#C0C0C0"] ;\n' \ @@ -148,21 +148,21 @@ def test_graphviz_toy(): 'node [shape=box, style="filled", color="black"] ;\n' \ '0 [label="X[0] <= 0.0\\nsamples = 6\\n' \ 'value = [[3.0, 1.5, 0.0]\\n' \ - '[3.0, 1.0, 0.5]]", fillcolor="#e5813900"] ;\n' \ + '[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n' \ '1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n' \ - '[3, 0, 0]]", fillcolor="#e58139ff"] ;\n' \ + '[3, 0, 0]]", fillcolor="#e58139"] ;\n' \ '0 -> 1 [labeldistance=2.5, labelangle=45, ' \ 'headlabel="True"] ;\n' \ '2 [label="X[0] <= 1.5\\nsamples = 3\\n' \ 'value = [[0.0, 1.5, 0.0]\\n' \ - '[0.0, 1.0, 0.5]]", fillcolor="#e5813986"] ;\n' \ + '[0.0, 1.0, 0.5]]", fillcolor="#f1bd97"] ;\n' \ '0 -> 2 [labeldistance=2.5, labelangle=-45, ' \ 'headlabel="False"] ;\n' \ '3 [label="samples = 2\\nvalue = [[0, 1, 0]\\n' \ - '[0, 1, 0]]", fillcolor="#e58139ff"] ;\n' \ + '[0, 1, 0]]", fillcolor="#e58139"] ;\n' \ '2 -> 3 ;\n' \ '4 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n' \ - '[0.0, 0.0, 0.5]]", fillcolor="#e58139ff"] ;\n' \ + '[0.0, 0.0, 0.5]]", fillcolor="#e58139"] ;\n' \ '2 -> 4 ;\n' \ '}' @@ -184,13 +184,13 @@ def test_graphviz_toy(): 'edge [fontname=helvetica] ;\n' \ 'rankdir=LR ;\n' \ '0 [label="X[0] <= 0.0\\nmse = 1.0\\nsamples = 6\\n' \ - 'value = 0.0", fillcolor="#e5813980"] ;\n' \ + 'value = 0.0", fillcolor="#f2c09c"] ;\n' \ '1 [label="mse = 0.0\\nsamples = 3\\nvalue = -1.0", ' \ - 'fillcolor="#e5813900"] ;\n' \ + 'fillcolor="#ffffff"] ;\n' \ '0 -> 1 [labeldistance=2.5, labelangle=-45, ' \ 'headlabel="True"] ;\n' \ '2 [label="mse = 0.0\\nsamples = 3\\nvalue = 1.0", ' \ - 'fillcolor="#e58139ff"] ;\n' \ + 'fillcolor="#e58139"] ;\n' \ '0 -> 2 [labeldistance=2.5, labelangle=45, ' \ 'headlabel="False"] ;\n' \ '{rank=same ; 0} ;\n' \ @@ -207,7 +207,7 @@ def test_graphviz_toy(): contents2 = 'digraph Tree {\n' \ 'node [shape=box, style="filled", color="black"] ;\n' \ '0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", ' \ - 'fillcolor="#e5813900"] ;\n' \ + 'fillcolor="#ffffff"] ;\n' \ '}' assert_equal(contents1, contents2) From 0229d5dbc621f86df58d3e4cf8795d0670dd5304 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 6 Jul 2017 18:33:36 -0400 Subject: [PATCH 14/61] make scalex and scaley internal, and ax local. render everything once to get the bbox sizes, then again to actually plot it with known extents. --- sklearn/tree/export.py | 68 ++++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 76da293b30742..3786d6c75f32a 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -78,19 +78,14 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3, ax=None, scalex=150): - import matplotlib.pyplot as plt - if ax is None: - ax = plt.gca() - + special_characters=False, precision=3, ax=None): exporter = _MPLTreeExporter( max_depth=max_depth, feature_names=feature_names, class_names=class_names, label=label, filled=filled, leaves_parallel=leaves_parallel, impurity=impurity, node_ids=node_ids, proportion=proportion, rotate=rotate, rounded=rounded, - special_characters=special_characters, precision=precision, ax=ax, - scalex=scalex) - exporter.export(decision_tree) + special_characters=special_characters, precision=precision) + exporter.export(decision_tree, ax=ax) class _DOTTreeExporter(object): @@ -411,11 +406,11 @@ def recurse(self, tree, node_id, criterion, parent=None, depth=0): class _MPLTreeExporter(_DOTTreeExporter): - def __init__(self, ax, max_depth=None, feature_names=None, + def __init__(self, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3, scalex=150, scaley=1): + special_characters=False, precision=3): self.max_depth = max_depth self.feature_names = feature_names self.class_names = class_names @@ -429,9 +424,7 @@ def __init__(self, ax, max_depth=None, feature_names=None, self.rounded = rounded self.special_characters = special_characters self.precision = precision - self.scalex = scalex - self.ax = ax - self.scaley = 80 if class_names is None else 100 + self._scaley = 80 if class_names is None else 100 # validate if isinstance(precision, Integral): @@ -465,24 +458,33 @@ def _make_tree(self, node_id, et): return Tree(name, node_id) return Tree(name, node_id, *children) - # def _find_longest(self, my_tree, max_length): - # child_length = [_find_longest(c, max_length) for c - # in my_cildren] - # return max(child_length + [max_length]) - - def export(self, decision_tree): - self.ax.set_axis_off() + def export(self, decision_tree, ax=None): + import matplotlib.pyplot as plt + from matplotlib.text import Annotation + if ax is None: + ax = plt.gca() + ax.set_axis_off() my_tree = self._make_tree(0, decision_tree.tree_) # find longest string: - # self._find_longest(my_tree, len(my_tree.node))) + # print(my_tree._longest_str()) dt = buchheim(my_tree) - self.recurse(dt, decision_tree.tree_) - - def recurse(self, node, tree, depth=0): + # plot once with phantom axis to get sizes: + ax_phantom = plt.gcf().add_axes((0, 0, 1, 1)) + self._scalex = 1 + self.recurse(dt, decision_tree.tree_, ax_phantom) + plt.draw() + bbox_widths = [ann.get_bbox_patch().get_width() for ann in + ax_phantom.get_children() + if isinstance(ann, Annotation)] + self._scalex = max(bbox_widths) + ax_phantom.set_visible(False) + self.recurse(dt, decision_tree.tree_, ax) + + def recurse(self, node, tree, ax, depth=0): # 2 - is a hack to for not creating empty space. FIXME kwargs = dict(bbox=self.bbox_args, ha='center', va='center', zorder=100 - 10 * depth, xycoords='axes points') - xy = (node.x * self.scalex, (2 - node.y) * self.scaley) + xy = (node.x * self._scalex, (2 - node.y) * self._scaley) if self.max_depth is None or depth <= self.max_depth: if self.filled: @@ -490,21 +492,21 @@ def recurse(self, node, tree, depth=0): node.tree.node_id) if node.parent is None: # root - self.ax.annotate(node.tree.node, xy, **kwargs) + ax.annotate(node.tree.node, xy, **kwargs) else: - xy_parent = (node.parent.x * self.scalex, (2 - node.parent.y) * - self.scaley) + xy_parent = (node.parent.x * self._scalex, + (2 - node.parent.y) * self._scaley) kwargs["arrowprops"] = self.arrow_args - self.ax.annotate(node.tree.node, xy_parent, xy, **kwargs) + ax.annotate(node.tree.node, xy_parent, xy, **kwargs) for child in node.children: - self.recurse(child, tree, depth=depth + 1) + self.recurse(child, tree, ax, depth=depth + 1) else: - xy_parent = (node.parent.x * self.scalex, (2 - node.parent.y) * - self.scaley) + xy_parent = (node.parent.x * self._scalex, (2 - node.parent.y) * + self._scaley) kwargs["arrowprops"] = self.arrow_args kwargs['bbox']['fc'] = 'grey' - self.ax.annotate("\n (...) \n", xy_parent, xy, **kwargs) + ax.annotate("\n (...) \n", xy_parent, xy, **kwargs) def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None, From a2df69ed3bab5f0f5867915615738764d21bb298 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 6 Jul 2017 18:36:21 -0400 Subject: [PATCH 15/61] add some margin to the max bbox width --- sklearn/tree/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 3786d6c75f32a..b8d0cf35359c5 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -476,7 +476,7 @@ def export(self, decision_tree, ax=None): bbox_widths = [ann.get_bbox_patch().get_width() for ann in ax_phantom.get_children() if isinstance(ann, Annotation)] - self._scalex = max(bbox_widths) + self._scalex = max(bbox_widths) + 2 # some margin ax_phantom.set_visible(False) self.recurse(dt, decision_tree.tree_, ax) From 5212f5942bd8e7f5c1940eb6b208defeb57658af Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 7 Jul 2017 11:41:02 -0400 Subject: [PATCH 16/61] add _BaseTreeExporter baseclass --- sklearn/tree/export.py | 202 +++++++++++++++++++++-------------------- 1 file changed, 102 insertions(+), 100 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index b8d0cf35359c5..1f23a585112b7 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -88,105 +88,7 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, exporter.export(decision_tree, ax=ax) -class _DOTTreeExporter(object): - def __init__(self, out_file=SENTINEL, max_depth=None, - feature_names=None, class_names=None, label='all', - filled=False, leaves_parallel=False, impurity=True, - node_ids=False, proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3): - self.out_file = out_file - self.max_depth = max_depth - self.feature_names = feature_names - self.class_names = class_names - self.label = label - self.filled = filled - self.leaves_parallel = leaves_parallel - self.impurity = impurity - self.node_ids = node_ids - self.proportion = proportion - self.rotate = rotate - self.rounded = rounded - self.special_characters = special_characters - self.precision = precision - - # PostScript compatibility for special characters - if special_characters: - self.characters = ['#', '', '', '≤', '
', - '>', '<'] - else: - self.characters = ['#', '[', ']', '<=', '\\n', '"', '"'] - - # validate - if isinstance(precision, Integral): - if precision < 0: - raise ValueError("'precision' should be greater or equal to 0." - " Got {} instead.".format(precision)) - else: - raise ValueError("'precision' should be an integer. Got {}" - " instead.".format(type(precision))) - - # The depth of each node for plotting with 'leaf' option - self.ranks = {'leaves': []} - # The colors to render each node with - self.colors = {'bounds': None} - - def export(self, decision_tree): - # Check length of feature_names before getting into the tree node - # Raise error if length of feature_names does not match - # n_features_ in the decision_tree - if self.feature_names is not None: - if len(self.feature_names) != decision_tree.n_features_: - raise ValueError("Length of feature_names, %d " - "does not match number of features, %d" - % (len(self.feature_names), - decision_tree.n_features_)) - # each part writes to out_file - self.head() - # Now recurse the tree and add node & edge attributes - if isinstance(decision_tree, _tree.Tree): - self.recurse(decision_tree, 0, criterion="impurity") - else: - self.recurse(decision_tree.tree_, 0, - criterion=decision_tree.criterion) - - self.tail() - - def tail(self): - # If required, draw leaf nodes at same depth as each other - if self.leaves_parallel: - for rank in sorted(self.ranks): - self.out_file.write( - "{rank=same ; " + - "; ".join(r for r in self.ranks[rank]) + "} ;\n") - self.out_file.write("}") - - def head(self): - self.out_file.write('digraph Tree {\n') - - # Specify node aesthetics - self.out_file.write('node [shape=box') - rounded_filled = [] - if self.filled: - rounded_filled.append('filled') - if self.rounded: - rounded_filled.append('rounded') - if len(rounded_filled) > 0: - self.out_file.write( - ', style="%s", color="black"' - % ", ".join(rounded_filled)) - if self.rounded: - self.out_file.write(', fontname=helvetica') - self.out_file.write('] ;\n') - - # Specify graph & edge aesthetics - if self.leaves_parallel: - self.out_file.write( - 'graph [ranksep=equally, splines=polyline] ;\n') - if self.rounded: - self.out_file.write('edge [fontname=helvetica] ;\n') - if self.rotate: - self.out_file.write('rankdir=LR ;\n') - +class _BaseTreeExporter(object): def get_color(self, value): # Find the appropriate color & intensity for a node if self.colors['bounds'] is None: @@ -343,6 +245,106 @@ def node_to_str(self, tree, node_id, criterion): return node_string + characters[5] + +class _DOTTreeExporter(_BaseTreeExporter): + def __init__(self, out_file=SENTINEL, max_depth=None, + feature_names=None, class_names=None, label='all', + filled=False, leaves_parallel=False, impurity=True, + node_ids=False, proportion=False, rotate=False, rounded=False, + special_characters=False, precision=3): + self.out_file = out_file + self.max_depth = max_depth + self.feature_names = feature_names + self.class_names = class_names + self.label = label + self.filled = filled + self.leaves_parallel = leaves_parallel + self.impurity = impurity + self.node_ids = node_ids + self.proportion = proportion + self.rotate = rotate + self.rounded = rounded + self.special_characters = special_characters + self.precision = precision + + # PostScript compatibility for special characters + if special_characters: + self.characters = ['#', '', '', '≤', '
', + '>', '<'] + else: + self.characters = ['#', '[', ']', '<=', '\\n', '"', '"'] + + # validate + if isinstance(precision, Integral): + if precision < 0: + raise ValueError("'precision' should be greater or equal to 0." + " Got {} instead.".format(precision)) + else: + raise ValueError("'precision' should be an integer. Got {}" + " instead.".format(type(precision))) + + # The depth of each node for plotting with 'leaf' option + self.ranks = {'leaves': []} + # The colors to render each node with + self.colors = {'bounds': None} + + def export(self, decision_tree): + # Check length of feature_names before getting into the tree node + # Raise error if length of feature_names does not match + # n_features_ in the decision_tree + if self.feature_names is not None: + if len(self.feature_names) != decision_tree.n_features_: + raise ValueError("Length of feature_names, %d " + "does not match number of features, %d" + % (len(self.feature_names), + decision_tree.n_features_)) + # each part writes to out_file + self.head() + # Now recurse the tree and add node & edge attributes + if isinstance(decision_tree, _tree.Tree): + self.recurse(decision_tree, 0, criterion="impurity") + else: + self.recurse(decision_tree.tree_, 0, + criterion=decision_tree.criterion) + + self.tail() + + def tail(self): + # If required, draw leaf nodes at same depth as each other + if self.leaves_parallel: + for rank in sorted(self.ranks): + self.out_file.write( + "{rank=same ; " + + "; ".join(r for r in self.ranks[rank]) + "} ;\n") + self.out_file.write("}") + + def head(self): + self.out_file.write('digraph Tree {\n') + + # Specify node aesthetics + self.out_file.write('node [shape=box') + rounded_filled = [] + if self.filled: + rounded_filled.append('filled') + if self.rounded: + rounded_filled.append('rounded') + if len(rounded_filled) > 0: + self.out_file.write( + ', style="%s", color="black"' + % ", ".join(rounded_filled)) + if self.rounded: + self.out_file.write(', fontname=helvetica') + self.out_file.write('] ;\n') + + # Specify graph & edge aesthetics + if self.leaves_parallel: + self.out_file.write( + 'graph [ranksep=equally, splines=polyline] ;\n') + if self.rounded: + self.out_file.write('edge [fontname=helvetica] ;\n') + if self.rotate: + self.out_file.write('rankdir=LR ;\n') + def recurse(self, tree, node_id, criterion, parent=None, depth=0): if node_id == _tree.TREE_LEAF: raise ValueError("Invalid node_id %s" % _tree.TREE_LEAF) @@ -405,7 +407,7 @@ def recurse(self, tree, node_id, criterion, parent=None, depth=0): self.out_file.write('%d -> %d ;\n' % (parent, node_id)) -class _MPLTreeExporter(_DOTTreeExporter): +class _MPLTreeExporter(_BaseTreeExporter): def __init__(self, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, From 60c0b73edb9b5fe3244d5844316b85e47b37774c Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 7 Jul 2017 11:48:03 -0400 Subject: [PATCH 17/61] add docstring to plot_tree --- sklearn/tree/export.py | 77 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 1f23a585112b7..6304adc0aac9c 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -79,6 +79,83 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, special_characters=False, precision=3, ax=None): + """Plot a decision tree. + + The sample counts that are shown are weighted with any sample_weights that + might be present. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + decision_tree : decision tree classifier + The decision tree to be exported to GraphViz. + + max_depth : int, optional (default=None) + The maximum depth of the representation. If None, the tree is fully + generated. + + feature_names : list of strings, optional (default=None) + Names of each of the features. + + class_names : list of strings, bool or None, optional (default=None) + Names of each of the target classes in ascending numerical order. + Only relevant for classification and not supported for multi-output. + If ``True``, shows a symbolic representation of the class name. + + label : {'all', 'root', 'none'}, optional (default='all') + Whether to show informative labels for impurity, etc. + Options include 'all' to show at every node, 'root' to show only at + the top root node, or 'none' to not show at any node. + + filled : bool, optional (default=False) + When set to ``True``, paint nodes to indicate majority class for + classification, extremity of values for regression, or purity of node + for multi-output. + + leaves_parallel : bool, optional (default=False) + When set to ``True``, draw all leaf nodes at the bottom of the tree. + + impurity : bool, optional (default=True) + When set to ``True``, show the impurity at each node. + + node_ids : bool, optional (default=False) + When set to ``True``, show the ID number on each node. + + proportion : bool, optional (default=False) + When set to ``True``, change the display of 'values' and/or 'samples' + to be proportions and percentages respectively. + + rotate : bool, optional (default=False) + When set to ``True``, orient tree left to right rather than top-down. + + rounded : bool, optional (default=False) + When set to ``True``, draw node boxes with rounded corners and use + Helvetica fonts instead of Times-Roman. + + special_characters : bool, optional (default=False) + When set to ``False``, ignore special characters for PostScript + compatibility. + + precision : int, optional (default=3) + Number of digits of precision for floating point in the values of + impurity, threshold and value attributes of each node. + + ax : matplotlib axis, optional (default=None) + Axes to plot to. If None, use current axis. + + Examples + -------- + >>> from sklearn.datasets import load_iris + >>> from sklearn import tree + + >>> clf = tree.DecisionTreeClassifier() + >>> iris = load_iris() + + >>> clf = clf.fit(iris.data, iris.target) + >>> tree.plot_tree(clf) # doctest: +SKIP + + """ exporter = _MPLTreeExporter( max_depth=max_depth, feature_names=feature_names, class_names=class_names, label=label, filled=filled, From 3b4a730382b7297d87af143330ffae871241351e Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 7 Jul 2017 12:54:55 -0400 Subject: [PATCH 18/61] use data coordinates so we can put the plot in a subplot, remove some hacks. --- sklearn/tree/export.py | 46 ++++++++++++++++++++++++------------------ 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 6304adc0aac9c..13855f4624949 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -78,7 +78,7 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3, ax=None): + special_characters=False, precision=3, ax=None, fontsize=None): """Plot a decision tree. The sample counts that are shown are weighted with any sample_weights that @@ -161,7 +161,8 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, class_names=class_names, label=label, filled=filled, leaves_parallel=leaves_parallel, impurity=impurity, node_ids=node_ids, proportion=proportion, rotate=rotate, rounded=rounded, - special_characters=special_characters, precision=precision) + special_characters=special_characters, precision=precision, + fontsize=fontsize) exporter.export(decision_tree, ax=ax) @@ -489,7 +490,7 @@ def __init__(self, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3): + special_characters=False, precision=3, fontsize=None): self.max_depth = max_depth self.feature_names = feature_names self.class_names = class_names @@ -503,7 +504,8 @@ def __init__(self, max_depth=None, feature_names=None, self.rounded = rounded self.special_characters = special_characters self.precision = precision - self._scaley = 80 if class_names is None else 100 + self.fontsize = fontsize + self._scaley = 1 # validate if isinstance(precision, Integral): @@ -524,6 +526,7 @@ def __init__(self, max_depth=None, feature_names=None, self.bbox_args = dict(fc='w') if self.rounded: self.bbox_args['boxstyle'] = "round" + self.arrow_args = dict(arrowstyle="<-") def _make_tree(self, node_id, et): @@ -544,26 +547,29 @@ def export(self, decision_tree, ax=None): ax = plt.gca() ax.set_axis_off() my_tree = self._make_tree(0, decision_tree.tree_) - # find longest string: - # print(my_tree._longest_str()) dt = buchheim(my_tree) - # plot once with phantom axis to get sizes: - ax_phantom = plt.gcf().add_axes((0, 0, 1, 1)) self._scalex = 1 - self.recurse(dt, decision_tree.tree_, ax_phantom) - plt.draw() - bbox_widths = [ann.get_bbox_patch().get_width() for ann in - ax_phantom.get_children() - if isinstance(ann, Annotation)] - self._scalex = max(bbox_widths) + 2 # some margin - ax_phantom.set_visible(False) self.recurse(dt, decision_tree.tree_, ax) + # get all the annotated points + xys = [ann.xyann for ann in ax.get_children() + if isinstance(ann, Annotation)] + + # set axis limits with slight margin of .5 + mins = np.min(xys, axis=0) - .5 + maxs = np.max(xys, axis=0) + .5 + + ax.set_xlim(mins[0], maxs[0]) + ax.set_ylim(maxs[1], mins[1]) + def recurse(self, node, tree, ax, depth=0): - # 2 - is a hack to for not creating empty space. FIXME kwargs = dict(bbox=self.bbox_args, ha='center', va='center', - zorder=100 - 10 * depth, xycoords='axes points') - xy = (node.x * self._scalex, (2 - node.y) * self._scaley) + zorder=100 - 10 * depth) + + if self.fontsize is not None: + kwargs['fontsize'] = self.fontsize + + xy = (node.x * self._scalex, node.y * self._scaley) if self.max_depth is None or depth <= self.max_depth: if self.filled: @@ -574,14 +580,14 @@ def recurse(self, node, tree, ax, depth=0): ax.annotate(node.tree.node, xy, **kwargs) else: xy_parent = (node.parent.x * self._scalex, - (2 - node.parent.y) * self._scaley) + node.parent.y * self._scaley) kwargs["arrowprops"] = self.arrow_args ax.annotate(node.tree.node, xy_parent, xy, **kwargs) for child in node.children: self.recurse(child, tree, ax, depth=depth + 1) else: - xy_parent = (node.parent.x * self._scalex, (2 - node.parent.y) * + xy_parent = (node.parent.x * self._scalex, node.parent.y * self._scaley) kwargs["arrowprops"] = self.arrow_args kwargs['bbox']['fc'] = 'grey' From a30f63418a80d1916ff136dea69bf894cb73a4e6 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 10 Jul 2017 12:26:40 -0400 Subject: [PATCH 19/61] remove scalex, scaley, add automatic font size --- sklearn/tree/export.py | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 13855f4624949..3fa0081cff3be 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -505,7 +505,6 @@ def __init__(self, max_depth=None, feature_names=None, self.special_characters = special_characters self.precision = precision self.fontsize = fontsize - self._scaley = 1 # validate if isinstance(precision, Integral): @@ -548,20 +547,40 @@ def export(self, decision_tree, ax=None): ax.set_axis_off() my_tree = self._make_tree(0, decision_tree.tree_) dt = buchheim(my_tree) - self._scalex = 1 self.recurse(dt, decision_tree.tree_, ax) + anns = [ann for ann in ax.get_children() + if isinstance(ann, Annotation)] + # get all the annotated points - xys = [ann.xyann for ann in ax.get_children() - if isinstance(ann, Annotation)] + xys = [ann.xyann for ann in anns] # set axis limits with slight margin of .5 - mins = np.min(xys, axis=0) - .5 + mins = np.min(xys, axis=0) maxs = np.max(xys, axis=0) + .5 ax.set_xlim(mins[0], maxs[0]) ax.set_ylim(maxs[1], mins[1]) + if self.fontsize is None: + # get figure to data transform + inv = ax.transData.inverted() + renderer = ax.figure.canvas.get_renderer() + # update sizes of all bboxes + for ann in anns: + ann.update_bbox_position_size(renderer) + # get max box width + widths = [inv.get_matrix()[0, 0] + * ann.get_bbox_patch().get_window_extent().width + for ann in anns] + # get minimum max size to not be too big. + max_width = max(max(widths), 1) + # adjust fontsize to avoid overlap + # width should be around 1 in data coordinates + size = anns[0].get_fontsize() / max_width + for ann in anns: + ann.set_fontsize(size) + def recurse(self, node, tree, ax, depth=0): kwargs = dict(bbox=self.bbox_args, ha='center', va='center', zorder=100 - 10 * depth) @@ -569,7 +588,7 @@ def recurse(self, node, tree, ax, depth=0): if self.fontsize is not None: kwargs['fontsize'] = self.fontsize - xy = (node.x * self._scalex, node.y * self._scaley) + xy = (node.x, node.y) if self.max_depth is None or depth <= self.max_depth: if self.filled: @@ -579,16 +598,14 @@ def recurse(self, node, tree, ax, depth=0): # root ax.annotate(node.tree.node, xy, **kwargs) else: - xy_parent = (node.parent.x * self._scalex, - node.parent.y * self._scaley) + xy_parent = (node.parent.x, node.parent.y) kwargs["arrowprops"] = self.arrow_args ax.annotate(node.tree.node, xy_parent, xy, **kwargs) for child in node.children: self.recurse(child, tree, ax, depth=depth + 1) else: - xy_parent = (node.parent.x * self._scalex, node.parent.y * - self._scaley) + xy_parent = (node.parent.x, node.parent.y) kwargs["arrowprops"] = self.arrow_args kwargs['bbox']['fc'] = 'grey' ax.annotate("\n (...) \n", xy_parent, xy, **kwargs) From 27a29acceb014ab2107bf3ccbe15d5aa47ca643c Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 10 Jul 2017 14:49:06 -0400 Subject: [PATCH 20/61] use rendered stuff for setting limits (well nearly there) --- sklearn/tree/export.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 3fa0081cff3be..862ec24f0f858 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -544,7 +544,7 @@ def export(self, decision_tree, ax=None): from matplotlib.text import Annotation if ax is None: ax = plt.gca() - ax.set_axis_off() + # ax.set_axis_off() my_tree = self._make_tree(0, decision_tree.tree_) dt = buchheim(my_tree) self.recurse(dt, decision_tree.tree_, ax) @@ -552,35 +552,38 @@ def export(self, decision_tree, ax=None): anns = [ann for ann in ax.get_children() if isinstance(ann, Annotation)] - # get all the annotated points - xys = [ann.xyann for ann in anns] - - # set axis limits with slight margin of .5 - mins = np.min(xys, axis=0) - maxs = np.max(xys, axis=0) + .5 + # update sizes of all bboxes + renderer = ax.figure.canvas.get_renderer() + for ann in anns: + ann.update_bbox_position_size(renderer) - ax.set_xlim(mins[0], maxs[0]) - ax.set_ylim(maxs[1], mins[1]) + # get figure to data transform + inv = ax.transData.inverted() if self.fontsize is None: - # get figure to data transform - inv = ax.transData.inverted() - renderer = ax.figure.canvas.get_renderer() - # update sizes of all bboxes - for ann in anns: - ann.update_bbox_position_size(renderer) + # adjust fontsize to avoid overlap # get max box width widths = [inv.get_matrix()[0, 0] * ann.get_bbox_patch().get_window_extent().width for ann in anns] # get minimum max size to not be too big. - max_width = max(max(widths), 1) - # adjust fontsize to avoid overlap + max_width = max(widths) # width should be around 1 in data coordinates size = anns[0].get_fontsize() / max_width for ann in anns: ann.set_fontsize(size) + # bboxes = [inv.transform(ann.get_bbox_patch().get_bbox()) for ann in anns] + # get all the annotated points + xys = [ann.xyann for ann in anns] + + # set axis limits with slight margin of .5 + mins = np.min(xys, axis=0) + maxs = np.max(xys, axis=0) + + ax.set_xlim(mins[0], maxs[0]) + ax.set_ylim(maxs[1], mins[1]) + def recurse(self, node, tree, ax, depth=0): kwargs = dict(bbox=self.bbox_args, ha='center', va='center', zorder=100 - 10 * depth) From 538d25753ab142a3d8ad43906a4761156d6eeaa6 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 12 Jul 2017 21:39:29 -0500 Subject: [PATCH 21/61] import plot_tree into tree module --- sklearn/tree/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/__init__.py b/sklearn/tree/__init__.py index 1394bd914d27c..5b3c66a5f11e6 100644 --- a/sklearn/tree/__init__.py +++ b/sklearn/tree/__init__.py @@ -7,7 +7,7 @@ from .tree import DecisionTreeRegressor from .tree import ExtraTreeClassifier from .tree import ExtraTreeRegressor -from .export import export_graphviz +from .export import export_graphviz, plot_tree __all__ = ["DecisionTreeClassifier", "DecisionTreeRegressor", - "ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz"] + "ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz", "plot_tree"] From c6ecbb2986f86975fc8ba24c7c02c99f3d458ab9 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 12 Jul 2017 21:40:15 -0500 Subject: [PATCH 22/61] set limits before font size adjustment? --- sklearn/tree/export.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 78d59ee9b7a78..b65446c3b71a8 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -544,7 +544,7 @@ def export(self, decision_tree, ax=None): from matplotlib.text import Annotation if ax is None: ax = plt.gca() - # ax.set_axis_off() + ax.set_axis_off() my_tree = self._make_tree(0, decision_tree.tree_) dt = buchheim(my_tree) self.recurse(dt, decision_tree.tree_, ax) @@ -557,10 +557,19 @@ def export(self, decision_tree, ax=None): for ann in anns: ann.update_bbox_position_size(renderer) - # get figure to data transform - inv = ax.transData.inverted() + # get all the annotated points + xys = [ann.xyann for ann in anns] + + # set axis limits + mins = np.min(xys, axis=0) + maxs = np.max(xys, axis=0) + + ax.set_xlim(mins[0], maxs[0]) + ax.set_ylim(maxs[1], mins[1]) if self.fontsize is None: + # get figure to data transform + inv = ax.transData.inverted() # adjust fontsize to avoid overlap # get max box width widths = [inv.get_matrix()[0, 0] @@ -573,17 +582,6 @@ def export(self, decision_tree, ax=None): for ann in anns: ann.set_fontsize(size) - # bboxes = [inv.transform(ann.get_bbox_patch().get_bbox()) for ann in anns] - # get all the annotated points - xys = [ann.xyann for ann in anns] - - # set axis limits with slight margin of .5 - mins = np.min(xys, axis=0) - maxs = np.max(xys, axis=0) - - ax.set_xlim(mins[0], maxs[0]) - ax.set_ylim(maxs[1], mins[1]) - def recurse(self, node, tree, ax, depth=0): kwargs = dict(bbox=self.bbox_args, ha='center', va='center', zorder=100 - 10 * depth) From fc7bdbe7c8dd5397dce3adf26a477e12374a8c8f Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 12 Jul 2017 21:42:31 -0500 Subject: [PATCH 23/61] add tree plotting via matplotlib to iris example and to docs --- doc/modules/tree.rst | 12 +++++++++++- examples/tree/plot_iris.py | 8 +++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/doc/modules/tree.rst b/doc/modules/tree.rst index f793c34b7f53d..922fcc17e2d88 100644 --- a/doc/modules/tree.rst +++ b/doc/modules/tree.rst @@ -124,7 +124,17 @@ Using the Iris dataset, we can construct a tree as follows:: >>> clf = tree.DecisionTreeClassifier() >>> clf = clf.fit(iris.data, iris.target) -Once trained, we can export the tree in `Graphviz +Once trained, you can plot the tree with the plot_tree function:: + + + >>> tree.plot_tree(clf.fit(iris.data, iris.target)) + +.. figure:: ../auto_examples/tree/images/sphx_glr_plot_iris_002.png + :target: ../auto_examples/tree/plot_iris.html + :scale: 75 + :align: center + +We can also export the tree in `Graphviz `_ format using the :func:`export_graphviz` exporter. If you use the `conda `_ package manager, the graphviz binaries and the python package can be installed with diff --git a/examples/tree/plot_iris.py b/examples/tree/plot_iris.py index d1b6e25b59a1c..97242db74bf49 100644 --- a/examples/tree/plot_iris.py +++ b/examples/tree/plot_iris.py @@ -11,6 +11,8 @@ For each pair of iris features, the decision tree learns decision boundaries made of combinations of simple thresholding rules inferred from the training samples. + +We also show the tree structure of a model build on all of the features. """ print(__doc__) @@ -18,7 +20,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import load_iris -from sklearn.tree import DecisionTreeClassifier +from sklearn.tree import DecisionTreeClassifier, plot_tree # Parameters n_classes = 3 @@ -63,4 +65,8 @@ plt.suptitle("Decision surface of a decision tree using paired features") plt.legend() + +plt.figure() +clf = DecisionTreeClassifier().fit(iris.data, iris.target) +plot_tree(clf, filled=True) plt.show() From 9d672ab73f1ce6656662e0a8d0a538463e7bdf81 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 13 Jul 2017 10:04:30 -0500 Subject: [PATCH 24/61] pep8 fix --- sklearn/tree/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/__init__.py b/sklearn/tree/__init__.py index 5b3c66a5f11e6..b3abe30d019fa 100644 --- a/sklearn/tree/__init__.py +++ b/sklearn/tree/__init__.py @@ -10,4 +10,5 @@ from .export import export_graphviz, plot_tree __all__ = ["DecisionTreeClassifier", "DecisionTreeRegressor", - "ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz", "plot_tree"] + "ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz", + "plot_tree"] From 1c8b8d6a9ed0184e61c962b117f5fb5e76c748e7 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 13 Jul 2017 10:05:42 -0500 Subject: [PATCH 25/61] skip doctest on plot_tree because matplotlib is not installed on all CI machines --- doc/modules/tree.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/modules/tree.rst b/doc/modules/tree.rst index 922fcc17e2d88..8630f52c7e277 100644 --- a/doc/modules/tree.rst +++ b/doc/modules/tree.rst @@ -127,7 +127,7 @@ Using the Iris dataset, we can construct a tree as follows:: Once trained, you can plot the tree with the plot_tree function:: - >>> tree.plot_tree(clf.fit(iris.data, iris.target)) + >>> tree.plot_tree(clf.fit(iris.data, iris.target)) # doctest: +SKIP .. figure:: ../auto_examples/tree/images/sphx_glr_plot_iris_002.png :target: ../auto_examples/tree/plot_iris.html From 474c557224a12f3a30a532871d838ff8c1a57c62 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 14 Jul 2017 17:01:04 -0500 Subject: [PATCH 26/61] redo everything in axis pixel coordinates re-introduce scalex, scaley add max_extents to tree to get tree size before plotting --- sklearn/tree/_reingold_tilford.py | 7 +++++ sklearn/tree/export.py | 45 ++++++++++++++++--------------- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/sklearn/tree/_reingold_tilford.py b/sklearn/tree/_reingold_tilford.py index 9e710dc75bca7..e62ffeaa6797f 100644 --- a/sklearn/tree/_reingold_tilford.py +++ b/sklearn/tree/_reingold_tilford.py @@ -1,5 +1,7 @@ # taken from https://github.com/llimllib/pymag-trees/blob/master/buchheim.py +import numpy as np + class DrawTree(object): def __init__(self, tree, parent=None, depth=0, number=1): @@ -47,6 +49,11 @@ def __str__(self): def __repr__(self): return self.__str__() + def max_extents(self): + extents = [c.max_extents() for c in self. children] + extents.append((self.x, self.y)) + return np.max(extents, axis=0) + def buchheim(tree): dt = firstwalk(DrawTree(tree)) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index b65446c3b71a8..04c7d06cd439c 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -547,7 +547,19 @@ def export(self, decision_tree, ax=None): ax.set_axis_off() my_tree = self._make_tree(0, decision_tree.tree_) dt = buchheim(my_tree) - self.recurse(dt, decision_tree.tree_, ax) + + # important to make sure we're still + # inside the axis after drawing the box + # this makes sense because the width of a box + # is about the same as the distance between boxes + max_x, max_y = dt.max_extents() + 1 + ax_width = ax.get_window_extent().width + ax_height = ax.get_window_extent().height + + scale_x = ax_width / max_x + scale_y = ax_height / max_y + + self.recurse(dt, decision_tree.tree_, ax, scale_x, scale_y, ax_height) anns = [ann for ann in ax.get_children() if isinstance(ann, Annotation)] @@ -557,39 +569,28 @@ def export(self, decision_tree, ax=None): for ann in anns: ann.update_bbox_position_size(renderer) - # get all the annotated points - xys = [ann.xyann for ann in anns] - - # set axis limits - mins = np.min(xys, axis=0) - maxs = np.max(xys, axis=0) - - ax.set_xlim(mins[0], maxs[0]) - ax.set_ylim(maxs[1], mins[1]) - if self.fontsize is None: # get figure to data transform - inv = ax.transData.inverted() # adjust fontsize to avoid overlap # get max box width - widths = [inv.get_matrix()[0, 0] - * ann.get_bbox_patch().get_window_extent().width + widths = [ann.get_bbox_patch().get_window_extent().width for ann in anns] # get minimum max size to not be too big. max_width = max(widths) - # width should be around 1 in data coordinates - size = anns[0].get_fontsize() / max_width + # width should be around scale_x in axis coordinates + size = anns[0].get_fontsize() / max_width * scale_x for ann in anns: ann.set_fontsize(size) - def recurse(self, node, tree, ax, depth=0): + def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0): kwargs = dict(bbox=self.bbox_args, ha='center', va='center', - zorder=100 - 10 * depth) + zorder=100 - 10 * depth, xycoords='axes pixels') if self.fontsize is not None: kwargs['fontsize'] = self.fontsize - xy = (node.x, node.y) + # offset things by .5 to center them in plot + xy = ((node.x + .5) * scale_x, height - (node.y + .5) * scale_y) if self.max_depth is None or depth <= self.max_depth: if self.filled: @@ -599,11 +600,13 @@ def recurse(self, node, tree, ax, depth=0): # root ax.annotate(node.tree.node, xy, **kwargs) else: - xy_parent = (node.parent.x, node.parent.y) + xy_parent = ((node.parent.x + .5) * scale_x, + height - (node.parent.y + .5) * scale_y) kwargs["arrowprops"] = self.arrow_args ax.annotate(node.tree.node, xy_parent, xy, **kwargs) for child in node.children: - self.recurse(child, tree, ax, depth=depth + 1) + self.recurse(child, tree, ax, scale_x, scale_y, height, + depth=depth + 1) else: xy_parent = (node.parent.x, node.parent.y) From 4c97f37effb0d4b0997e1727e4ccde30a68fa493 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 14 Jul 2017 17:07:43 -0500 Subject: [PATCH 27/61] fix max-depth parent node positioning and don't consider deep nodes in layouting --- sklearn/tree/export.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 04c7d06cd439c..6b49a2b94edae 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -528,13 +528,16 @@ def __init__(self, max_depth=None, feature_names=None, self.arrow_args = dict(arrowstyle="<-") - def _make_tree(self, node_id, et): + def _make_tree(self, node_id, et, depth=0): # traverses _tree.Tree recursively, builds intermediate # "_reingold_tilford.Tree" object name = self.node_to_str(et, node_id, criterion='entropy') - if (et.children_left[node_id] != et.children_right[node_id]): - children = [self._make_tree(et.children_left[node_id], et), - self._make_tree(et.children_right[node_id], et)] + if (et.children_left[node_id] != et.children_right[node_id] + and depth <= self.max_depth): + children = [self._make_tree(et.children_left[node_id], et, + depth=depth + 1), + self._make_tree(et.children_right[node_id], et, + depth=depth + 1)] else: return Tree(name, node_id) return Tree(name, node_id, *children) @@ -609,7 +612,8 @@ def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0): depth=depth + 1) else: - xy_parent = (node.parent.x, node.parent.y) + xy_parent = ((node.parent.x + .5) * scale_x, + height - (node.parent.y + .5) * scale_y) kwargs["arrowprops"] = self.arrow_args kwargs['bbox']['fc'] = 'grey' ax.annotate("\n (...) \n", xy_parent, xy, **kwargs) From b31e7ecccca61bb51994fed6ed22f5b1f51af446 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 14 Jul 2017 17:12:53 -0500 Subject: [PATCH 28/61] consider height in fontsize computation in case someone gave us a very flat figure --- sklearn/tree/export.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 6b49a2b94edae..11686982aaeac 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -575,13 +575,14 @@ def export(self, decision_tree, ax=None): if self.fontsize is None: # get figure to data transform # adjust fontsize to avoid overlap - # get max box width - widths = [ann.get_bbox_patch().get_window_extent().width - for ann in anns] - # get minimum max size to not be too big. - max_width = max(widths) + # get max box width and height + extents = [ann.get_bbox_patch().get_window_extent() + for ann in anns] + max_width = max([extent.width for extent in extents]) + max_height = max([extent.height for extent in extents]) # width should be around scale_x in axis coordinates - size = anns[0].get_fontsize() / max_width * scale_x + size = anns[0].get_fontsize() * min(scale_x / max_width, + scale_y / max_height) for ann in anns: ann.set_fontsize(size) From 9f3664882cb3e41881ce0a51270b758f27bc1217 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Sat, 15 Jul 2017 16:05:44 -0500 Subject: [PATCH 29/61] fix error when max_depth is None --- sklearn/tree/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 11686982aaeac..b20821925149d 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -533,7 +533,7 @@ def _make_tree(self, node_id, et, depth=0): # "_reingold_tilford.Tree" object name = self.node_to_str(et, node_id, criterion='entropy') if (et.children_left[node_id] != et.children_right[node_id] - and depth <= self.max_depth): + and (self.max_depth is None or depth <= self.max_depth)): children = [self._make_tree(et.children_left[node_id], et, depth=depth + 1), self._make_tree(et.children_right[node_id], et, From 752135ee75a33da3ce72ec84838814dc08c38dbe Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Tue, 21 Nov 2017 17:29:07 -0500 Subject: [PATCH 30/61] add docstring for tree plotting fontsize --- sklearn/tree/export.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index e166c1b26d9d2..f4bd28025cfe6 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -144,6 +144,10 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, ax : matplotlib axis, optional (default=None) Axes to plot to. If None, use current axis. + fontsize : int, optional (default=None) + Size of text font. If None, determined automatically to fit figure. + + Examples -------- >>> from sklearn.datasets import load_iris From fe92f7416ff91dfcb0361571880a8ec654b5fd90 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 28 Jun 2018 13:29:27 -0400 Subject: [PATCH 31/61] starting on jnothman's review --- examples/tree/plot_iris.py | 2 +- sklearn/tree/_reingold_tilford.py | 11 +- sklearn/tree/export.py | 25 +- tree_plotting.py | 561 ++++++++++++++++++++++++++++++ 4 files changed, 581 insertions(+), 18 deletions(-) create mode 100644 tree_plotting.py diff --git a/examples/tree/plot_iris.py b/examples/tree/plot_iris.py index 6d1ab10713000..60328c4f90d4f 100644 --- a/examples/tree/plot_iris.py +++ b/examples/tree/plot_iris.py @@ -12,7 +12,7 @@ boundaries made of combinations of simple thresholding rules inferred from the training samples. -We also show the tree structure of a model build on all of the features. +We also show the tree structure of a model built on all of the features. """ print(__doc__) diff --git a/sklearn/tree/_reingold_tilford.py b/sklearn/tree/_reingold_tilford.py index e62ffeaa6797f..0ca79d89887b3 100644 --- a/sklearn/tree/_reingold_tilford.py +++ b/sklearn/tree/_reingold_tilford.py @@ -56,7 +56,7 @@ def max_extents(self): def buchheim(tree): - dt = firstwalk(DrawTree(tree)) + dt = first_walk(DrawTree(tree)) min = second_walk(dt) if min < 0: third_walk(dt, -min) @@ -69,7 +69,7 @@ def third_walk(tree, n): third_walk(c, n) -def firstwalk(v, distance=1.): +def first_walk(v, distance=1.): if len(v.children) == 0: if v.lmost_sibling: v.x = v.lbrother().x + distance @@ -78,7 +78,7 @@ def firstwalk(v, distance=1.): else: default_ancestor = v.children[0] for w in v.children: - firstwalk(w) + first_walk(w) default_ancestor = apportion(w, default_ancestor, distance) # print("finished v =", v.tree, "children") execute_shifts(v) @@ -178,9 +178,8 @@ def second_walk(v, m=0, depth=0, min=None): class Tree(object): - def __init__(self, node="", node_id=-1, *children): - self.node = node - self.width = len(node) + def __init__(self, label="", node_id=-1, *children): + self.label = label self.node_id = node_id if children: self.children = children diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 5f4d237247472..d73581e0983f7 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -146,6 +146,10 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, fontsize : int, optional (default=None) Size of text font. If None, determined automatically to fit figure. + Returns + ------- + annotations : list of artists + List containing the artists for the annotation boxes making up the tree. Examples -------- @@ -156,7 +160,8 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, >>> iris = load_iris() >>> clf = clf.fit(iris.data, iris.target) - >>> tree.plot_tree(clf) # doctest: +SKIP + .. plot:: + >>> tree.plot_tree(clf) """ exporter = _MPLTreeExporter( @@ -166,7 +171,7 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, proportion=proportion, rotate=rotate, rounded=rounded, special_characters=special_characters, precision=precision, fontsize=fontsize) - exporter.export(decision_tree, ax=ax) + return exporter.export(decision_tree, ax=ax) class _BaseTreeExporter(object): @@ -191,11 +196,7 @@ def get_color(self, value): # compute the color as alpha against white color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color] # Return html color code in #RRGGBB format - hex_codes = [str(i) for i in range(10)] - hex_codes.extend(['a', 'b', 'c', 'd', 'e', 'f']) - color = [hex_codes[c // 16] + hex_codes[c % 16] for c in color] - - return '#' + ''.join(color) + return '#%2x%2x%2x' % tuple(color) def get_fill_color(self, tree, node_id): # Fetch appropriate color for node @@ -535,7 +536,7 @@ def _make_tree(self, node_id, et, depth=0): # traverses _tree.Tree recursively, builds intermediate # "_reingold_tilford.Tree" object name = self.node_to_str(et, node_id, criterion='entropy') - if (et.children_left[node_id] != et.children_right[node_id] + if (et.children_left[node_id] != _tree.LEAF and (self.max_depth is None or depth <= self.max_depth)): children = [self._make_tree(et.children_left[node_id], et, depth=depth + 1), @@ -552,20 +553,21 @@ def export(self, decision_tree, ax=None): ax = plt.gca() ax.set_axis_off() my_tree = self._make_tree(0, decision_tree.tree_) - dt = buchheim(my_tree) + draw_tree = buchheim(my_tree) # important to make sure we're still # inside the axis after drawing the box # this makes sense because the width of a box # is about the same as the distance between boxes - max_x, max_y = dt.max_extents() + 1 + max_x, max_y = draw_tree.max_extents() + 1 ax_width = ax.get_window_extent().width ax_height = ax.get_window_extent().height scale_x = ax_width / max_x scale_y = ax_height / max_y - self.recurse(dt, decision_tree.tree_, ax, scale_x, scale_y, ax_height) + self.recurse(draw_tree, decision_tree.tree_, ax, + scale_x, scale_y, ax_height) anns = [ann for ann in ax.get_children() if isinstance(ann, Annotation)] @@ -588,6 +590,7 @@ def export(self, decision_tree, ax=None): scale_y / max_height) for ann in anns: ann.set_fontsize(size) + return anns def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0): kwargs = dict(bbox=self.bbox_args, ha='center', va='center', diff --git a/tree_plotting.py b/tree_plotting.py new file mode 100644 index 0000000000000..19d86a53c264b --- /dev/null +++ b/tree_plotting.py @@ -0,0 +1,561 @@ +import numpy as np +from numbers import Integral + +from sklearn.externals import six +from sklearn.tree.export import _color_brew, _criterion, _tree + + +def plot_tree(decision_tree, max_depth=None, feature_names=None, + class_names=None, label='all', filled=False, + leaves_parallel=False, impurity=True, node_ids=False, + proportion=False, rotate=False, rounded=False, + special_characters=False, precision=3, ax=None, fontsize=None): + """Plot a decision tree. + + The sample counts that are shown are weighted with any sample_weights that + might be present. + + Parameters + ---------- + decision_tree : decision tree classifier + The decision tree to be exported to GraphViz. + + max_depth : int, optional (default=None) + The maximum depth of the representation. If None, the tree is fully + generated. + + feature_names : list of strings, optional (default=None) + Names of each of the features. + + class_names : list of strings, bool or None, optional (default=None) + Names of each of the target classes in ascending numerical order. + Only relevant for classification and not supported for multi-output. + If ``True``, shows a symbolic representation of the class name. + + label : {'all', 'root', 'none'}, optional (default='all') + Whether to show informative labels for impurity, etc. + Options include 'all' to show at every node, 'root' to show only at + the top root node, or 'none' to not show at any node. + + filled : bool, optional (default=False) + When set to ``True``, paint nodes to indicate majority class for + classification, extremity of values for regression, or purity of node + for multi-output. + + leaves_parallel : bool, optional (default=False) + When set to ``True``, draw all leaf nodes at the bottom of the tree. + + impurity : bool, optional (default=True) + When set to ``True``, show the impurity at each node. + + node_ids : bool, optional (default=False) + When set to ``True``, show the ID number on each node. + + proportion : bool, optional (default=False) + When set to ``True``, change the display of 'values' and/or 'samples' + to be proportions and percentages respectively. + + rotate : bool, optional (default=False) + When set to ``True``, orient tree left to right rather than top-down. + + rounded : bool, optional (default=False) + When set to ``True``, draw node boxes with rounded corners and use + Helvetica fonts instead of Times-Roman. + + special_characters : bool, optional (default=False) + When set to ``False``, ignore special characters for PostScript + compatibility. + + precision : int, optional (default=3) + Number of digits of precision for floating point in the values of + impurity, threshold and value attributes of each node. + + ax : matplotlib axis, optional (default=None) + Axes to plot to. If None, use current axis. + + Examples + -------- + >>> from sklearn.datasets import load_iris + + >>> clf = tree.DecisionTreeClassifier() + >>> iris = load_iris() + + >>> clf = clf.fit(iris.data, iris.target) + >>> plot_tree(clf) # doctest: +SKIP + + """ + exporter = _MPLTreeExporter( + max_depth=max_depth, feature_names=feature_names, + class_names=class_names, label=label, filled=filled, + leaves_parallel=leaves_parallel, impurity=impurity, node_ids=node_ids, + proportion=proportion, rotate=rotate, rounded=rounded, + special_characters=special_characters, precision=precision, + fontsize=fontsize) + exporter.export(decision_tree, ax=ax) + + +class _BaseTreeExporter(object): + def get_color(self, value): + # Find the appropriate color & intensity for a node + if self.colors['bounds'] is None: + # Classification tree + color = list(self.colors['rgb'][np.argmax(value)]) + sorted_values = sorted(value, reverse=True) + if len(sorted_values) == 1: + alpha = 0 + else: + alpha = ((sorted_values[0] - sorted_values[1]) + / (1 - sorted_values[1])) + else: + # Regression tree or multi-output + color = list(self.colors['rgb'][0]) + alpha = ((value - self.colors['bounds'][0]) / + (self.colors['bounds'][1] - self.colors['bounds'][0])) + # unpack numpy scalars + alpha = float(alpha) + # compute the color as alpha against white + color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color] + # Return html color code in #RRGGBB format + hex_codes = [str(i) for i in range(10)] + hex_codes.extend(['a', 'b', 'c', 'd', 'e', 'f']) + color = [hex_codes[c // 16] + hex_codes[c % 16] for c in color] + + return '#' + ''.join(color) + + def get_fill_color(self, tree, node_id): + # Fetch appropriate color for node + if 'rgb' not in self.colors: + # Initialize colors and bounds if required + self.colors['rgb'] = _color_brew(tree.n_classes[0]) + if tree.n_outputs != 1: + # Find max and min impurities for multi-output + self.colors['bounds'] = (np.min(-tree.impurity), + np.max(-tree.impurity)) + elif (tree.n_classes[0] == 1 and + len(np.unique(tree.value)) != 1): + # Find max and min values in leaf nodes for regression + self.colors['bounds'] = (np.min(tree.value), + np.max(tree.value)) + if tree.n_outputs == 1: + node_val = (tree.value[node_id][0, :] / + tree.weighted_n_node_samples[node_id]) + if tree.n_classes[0] == 1: + # Regression + node_val = tree.value[node_id][0, :] + else: + # If multi-output color node by impurity + node_val = -tree.impurity[node_id] + return self.get_color(node_val) + + def node_to_str(self, tree, node_id, criterion): + # Generate the node content string + if tree.n_outputs == 1: + value = tree.value[node_id][0, :] + else: + value = tree.value[node_id] + + # Should labels be shown? + labels = (self.label == 'root' and node_id == 0) or self.label == 'all' + + characters = self.characters + node_string = characters[-1] + + # Write node ID + if self.node_ids: + if labels: + node_string += 'node ' + node_string += characters[0] + str(node_id) + characters[4] + + # Write decision criteria + if tree.children_left[node_id] != _tree.TREE_LEAF: + # Always write node decision criteria, except for leaves + if self.feature_names is not None: + feature = self.feature_names[tree.feature[node_id]] + else: + feature = "X%s%s%s" % (characters[1], + tree.feature[node_id], + characters[2]) + node_string += '%s %s %s%s' % (feature, + characters[3], + round(tree.threshold[node_id], + self.precision), + characters[4]) + + # Write impurity + if self.impurity: + if isinstance(criterion, _criterion.FriedmanMSE): + criterion = "friedman_mse" + elif not isinstance(criterion, six.string_types): + criterion = "impurity" + if labels: + node_string += '%s = ' % criterion + node_string += (str(round(tree.impurity[node_id], self.precision)) + + characters[4]) + + # Write node sample count + if labels: + node_string += 'samples = ' + if self.proportion: + percent = (100. * tree.n_node_samples[node_id] / + float(tree.n_node_samples[0])) + node_string += (str(round(percent, 1)) + '%' + + characters[4]) + else: + node_string += (str(tree.n_node_samples[node_id]) + + characters[4]) + + # Write node class distribution / regression value + if self.proportion and tree.n_classes[0] != 1: + # For classification this will show the proportion of samples + value = value / tree.weighted_n_node_samples[node_id] + if labels: + node_string += 'value = ' + if tree.n_classes[0] == 1: + # Regression + value_text = np.around(value, self.precision) + elif self.proportion: + # Classification + value_text = np.around(value, self.precision) + elif np.all(np.equal(np.mod(value, 1), 0)): + # Classification without floating-point weights + value_text = value.astype(int) + else: + # Classification with floating-point weights + value_text = np.around(value, self.precision) + # Strip whitespace + value_text = str(value_text.astype('S32')).replace("b'", "'") + value_text = value_text.replace("' '", ", ").replace("'", "") + if tree.n_classes[0] == 1 and tree.n_outputs == 1: + value_text = value_text.replace("[", "").replace("]", "") + value_text = value_text.replace("\n ", characters[4]) + node_string += value_text + characters[4] + + # Write node majority class + if (self.class_names is not None and + tree.n_classes[0] != 1 and + tree.n_outputs == 1): + # Only done for single-output classification trees + if labels: + node_string += 'class = ' + if self.class_names is not True: + class_name = self.class_names[np.argmax(value)] + else: + class_name = "y%s%s%s" % (characters[1], + np.argmax(value), + characters[2]) + node_string += class_name + + # Clean up any trailing newlines + if node_string.endswith(characters[4]): + node_string = node_string[:-len(characters[4])] + + return node_string + characters[5] + + +class _MPLTreeExporter(_BaseTreeExporter): + def __init__(self, max_depth=None, feature_names=None, + class_names=None, label='all', filled=False, + leaves_parallel=False, impurity=True, node_ids=False, + proportion=False, rotate=False, rounded=False, + special_characters=False, precision=3, fontsize=None): + self.max_depth = max_depth + self.feature_names = feature_names + self.class_names = class_names + self.label = label + self.filled = filled + self.leaves_parallel = leaves_parallel + self.impurity = impurity + self.node_ids = node_ids + self.proportion = proportion + self.rotate = rotate + self.rounded = rounded + self.special_characters = special_characters + self.precision = precision + self.fontsize = fontsize + self._scaley = 10 + + # validate + if isinstance(precision, Integral): + if precision < 0: + raise ValueError("'precision' should be greater or equal to 0." + " Got {} instead.".format(precision)) + else: + raise ValueError("'precision' should be an integer. Got {}" + " instead.".format(type(precision))) + + # The depth of each node for plotting with 'leaf' option + self.ranks = {'leaves': []} + # The colors to render each node with + self.colors = {'bounds': None} + + self.characters = ['#', '[', ']', '<=', '\n', '', ''] + + self.bbox_args = dict(fc='w') + if self.rounded: + self.bbox_args['boxstyle'] = "round" + self.arrow_args = dict(arrowstyle="<-") + + def _make_tree(self, node_id, et): + # traverses _tree.Tree recursively, builds intermediate + # "_reingold_tilford.Tree" object + name = self.node_to_str(et, node_id, criterion='entropy') + if (et.children_left[node_id] != et.children_right[node_id]): + children = [self._make_tree(et.children_left[node_id], et), + self._make_tree(et.children_right[node_id], et)] + else: + return Tree(name, node_id) + return Tree(name, node_id, *children) + + def export(self, decision_tree, ax=None): + import matplotlib.pyplot as plt + from matplotlib.text import Annotation + + if ax is None: + ax = plt.gca() + ax.set_axis_off() + my_tree = self._make_tree(0, decision_tree.tree_) + dt = buchheim(my_tree) + self._scalex = 1 + self.recurse(dt, decision_tree.tree_, ax) + + anns = [ann for ann in ax.get_children() + if isinstance(ann, Annotation)] + + # get all the annotated points + xys = [ann.xyann for ann in anns] + + mins = np.min(xys, axis=0) + maxs = np.max(xys, axis=0) + + ax.set_xlim(mins[0], maxs[0]) + ax.set_ylim(maxs[1], mins[1]) + + if self.fontsize is None: + # get figure to data transform + inv = ax.transData.inverted() + renderer = ax.figure.canvas.get_renderer() + # update sizes of all bboxes + for ann in anns: + ann.update_bbox_position_size(renderer) + # get max box width + widths = [inv.get_matrix()[0, 0] + * ann.get_bbox_patch().get_window_extent().width + for ann in anns] + # get minimum max size to not be too big. + max_width = max(max(widths), 1) + # adjust fontsize to avoid overlap + # width should be around 1 in data coordinates + size = anns[0].get_fontsize() / max_width + for ann in anns: + ann.set_fontsize(size) + + def recurse(self, node, tree, ax, depth=0): + kwargs = dict(bbox=self.bbox_args, ha='center', va='center', + zorder=100 - 10 * depth) + + if self.fontsize is not None: + kwargs['fontsize'] = self.fontsize + + xy = (node.x * self._scalex, node.y * self._scaley) + + if self.max_depth is None or depth <= self.max_depth: + if self.filled: + kwargs['bbox']['fc'] = self.get_fill_color(tree, + node.tree.node_id) + if node.parent is None: + # root + ax.annotate(node.tree.node, xy, **kwargs) + else: + xy_parent = (node.parent.x * self._scalex, + node.parent.y * self._scaley) + kwargs["arrowprops"] = self.arrow_args + ax.annotate(node.tree.node, xy_parent, xy, **kwargs) + for child in node.children: + self.recurse(child, tree, ax, depth=depth + 1) + + else: + xy_parent = (node.parent.x * self._scalex, node.parent.y * + self._scaley) + kwargs["arrowprops"] = self.arrow_args + kwargs['bbox']['fc'] = 'grey' + ax.annotate("\n (...) \n", xy_parent, xy, **kwargs) + + +class DrawTree(object): + def __init__(self, tree, parent=None, depth=0, number=1): + self.x = -1. + self.y = depth + self.tree = tree + self.children = [DrawTree(c, self, depth + 1, i + 1) + for i, c + in enumerate(tree.children)] + self.parent = parent + self.thread = None + self.mod = 0 + self.ancestor = self + self.change = self.shift = 0 + self._lmost_sibling = None + # this is the number of the node in its group of siblings 1..n + self.number = number + + def left(self): + return self.thread or len(self.children) and self.children[0] + + def right(self): + return self.thread or len(self.children) and self.children[-1] + + def lbrother(self): + n = None + if self.parent: + for node in self.parent.children: + if node == self: + return n + else: + n = node + return n + + def get_lmost_sibling(self): + if not self._lmost_sibling and self.parent and self != \ + self.parent.children[0]: + self._lmost_sibling = self.parent.children[0] + return self._lmost_sibling + lmost_sibling = property(get_lmost_sibling) + + def __str__(self): + return "%s: x=%s mod=%s" % (self.tree, self.x, self.mod) + + def __repr__(self): + return self.__str__() + + +def buchheim(tree): + dt = firstwalk(DrawTree(tree)) + min = second_walk(dt) + if min < 0: + third_walk(dt, -min) + return dt + + +def third_walk(tree, n): + tree.x += n + for c in tree.children: + third_walk(c, n) + + +def firstwalk(v, distance=1.): + if len(v.children) == 0: + if v.lmost_sibling: + v.x = v.lbrother().x + distance + else: + v.x = 0. + else: + default_ancestor = v.children[0] + for w in v.children: + firstwalk(w) + default_ancestor = apportion(w, default_ancestor, distance) + # print("finished v =", v.tree, "children") + execute_shifts(v) + + midpoint = (v.children[0].x + v.children[-1].x) / 2 + + w = v.lbrother() + if w: + v.x = w.x + distance + v.mod = v.x - midpoint + else: + v.x = midpoint + return v + + +def apportion(v, default_ancestor, distance): + w = v.lbrother() + if w is not None: + # in buchheim notation: + # i == inner; o == outer; r == right; l == left; r = +; l = - + vir = vor = v + vil = w + vol = v.lmost_sibling + sir = sor = v.mod + sil = vil.mod + sol = vol.mod + while vil.right() and vir.left(): + vil = vil.right() + vir = vir.left() + vol = vol.left() + vor = vor.right() + vor.ancestor = v + shift = (vil.x + sil) - (vir.x + sir) + distance + if shift > 0: + move_subtree(ancestor(vil, v, default_ancestor), v, shift) + sir = sir + shift + sor = sor + shift + sil += vil.mod + sir += vir.mod + sol += vol.mod + sor += vor.mod + if vil.right() and not vor.right(): + vor.thread = vil.right() + vor.mod += sil - sor + else: + if vir.left() and not vol.left(): + vol.thread = vir.left() + vol.mod += sir - sol + default_ancestor = v + return default_ancestor + + +def move_subtree(wl, wr, shift): + subtrees = wr.number - wl.number + # print(wl.tree, "is conflicted with", wr.tree, 'moving', subtrees, + # 'shift', shift) + # print wl, wr, wr.number, wl.number, shift, subtrees, shift/subtrees + wr.change -= shift / subtrees + wr.shift += shift + wl.change += shift / subtrees + wr.x += shift + wr.mod += shift + + +def execute_shifts(v): + shift = change = 0 + for w in v.children[::-1]: + # print("shift:", w, shift, w.change) + w.x += shift + w.mod += shift + change += w.change + shift += w.shift + change + + +def ancestor(vil, v, default_ancestor): + # the relevant text is at the bottom of page 7 of + # "Improving Walker's Algorithm to Run in Linear Time" by Buchheim et al, + # (2002) + # http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.16.8757&rep=rep1&type=pdf + if vil.ancestor in v.parent.children: + return vil.ancestor + else: + return default_ancestor + + +def second_walk(v, m=0, depth=0, min=None): + v.x += m + v.y = depth + + if min is None or v.x < min: + min = v.x + + for w in v.children: + min = second_walk(w, m + v.mod, depth + 1, min) + + return min + + +class Tree(object): + def __init__(self, node="", node_id=-1, *children): + self.node = node + self.width = len(node) + self.node_id = node_id + if children: + self.children = children + else: + self.children = [] From 83a4a2e7496625d975b9f4981ee90b58f0933e13 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 28 Jun 2018 13:35:35 -0400 Subject: [PATCH 32/61] renaming fixes --- sklearn/tree/export.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index d73581e0983f7..d733d0c911513 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -85,6 +85,8 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, Read more in the :ref:`User Guide `. + .. versionadded:: 0.20 + Parameters ---------- decision_tree : decision tree regressor or classifier @@ -536,7 +538,7 @@ def _make_tree(self, node_id, et, depth=0): # traverses _tree.Tree recursively, builds intermediate # "_reingold_tilford.Tree" object name = self.node_to_str(et, node_id, criterion='entropy') - if (et.children_left[node_id] != _tree.LEAF + if (et.children_left[node_id] != _tree.TREE_LEAF and (self.max_depth is None or depth <= self.max_depth)): children = [self._make_tree(et.children_left[node_id], et, depth=depth + 1), @@ -608,12 +610,12 @@ def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0): node.tree.node_id) if node.parent is None: # root - ax.annotate(node.tree.node, xy, **kwargs) + ax.annotate(node.tree.label, xy, **kwargs) else: xy_parent = ((node.parent.x + .5) * scale_x, height - (node.parent.y + .5) * scale_y) kwargs["arrowprops"] = self.arrow_args - ax.annotate(node.tree.node, xy_parent, xy, **kwargs) + ax.annotate(node.tree.label, xy_parent, xy, **kwargs) for child in node.children: self.recurse(child, tree, ax, scale_x, scale_y, height, depth=depth + 1) From a1ba414a9545b2821e39b208f605106534fadffc Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 28 Jun 2018 13:38:35 -0400 Subject: [PATCH 33/61] whatsnew for tree plotting --- doc/whats_new/v0.20.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 802976c2b380c..7f212ed1b8bc8 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -178,6 +178,10 @@ Misc :func:`metrics.pairwise_distances_chunked`. See :ref:`working_memory`. :issue:`10280` by `Joel Nothman`_ and :user:`Aman Dalmia `. +- Decision Trees can now be plotted with matplotlib using :func:`tree.export.plot_tree` + without relying on the ``dot`` library, removing a hard-to-install dependency. + :issue:`8508` by `Andreas Müller`_. + Enhancements ............ From df6620dc9ab11be4a2506cefbe694d293602084b Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 28 Jun 2018 13:43:14 -0400 Subject: [PATCH 34/61] clear axes prior to doing anything. --- sklearn/tree/export.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index d733d0c911513..7931b48b21085 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -143,7 +143,8 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, impurity, threshold and value attributes of each node. ax : matplotlib axis, optional (default=None) - Axes to plot to. If None, use current axis. + Axes to plot to. If None, use current axis. Any previous content + is cleared. fontsize : int, optional (default=None) Size of text font. If None, determined automatically to fit figure. @@ -553,6 +554,7 @@ def export(self, decision_tree, ax=None): from matplotlib.text import Annotation if ax is None: ax = plt.gca() + ax.clear() ax.set_axis_off() my_tree = self._make_tree(0, decision_tree.tree_) draw_tree = buchheim(my_tree) From 02aeeab1ade9f24f448b3a221eeb7dcf8b05a0ca Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 28 Jun 2018 14:28:22 -0400 Subject: [PATCH 35/61] fix doctests --- sklearn/tree/export.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 7931b48b21085..a613667ccfea2 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -159,12 +159,12 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, >>> from sklearn.datasets import load_iris >>> from sklearn import tree - >>> clf = tree.DecisionTreeClassifier() + >>> clf = tree.DecisionTreeClassifier(random_state=0) >>> iris = load_iris() >>> clf = clf.fit(iris.data, iris.target) - .. plot:: - >>> tree.plot_tree(clf) + >>> tree.plot_tree(clf) # doctest: +ELLIPSIS + [Text(251.5,345.217,'X[3] <= 0.8... """ exporter = _MPLTreeExporter( @@ -728,8 +728,8 @@ def export_graphviz(decision_tree, out_file=None, max_depth=None, >>> iris = load_iris() >>> clf = clf.fit(iris.data, iris.target) - >>> tree.export_graphviz(clf) - + >>> tree.export_graphviz(clf) # doctest: +ELLIPSIS + 'digraph Tree {... """ check_is_fitted(decision_tree, 'tree_') From 072a66bc3d7903e574294ba147ff2222eb5e1214 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 28 Jun 2018 15:06:08 -0400 Subject: [PATCH 36/61] skip matplotlib doctest --- sklearn/tree/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index a613667ccfea2..29ea70abaf6c8 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -163,7 +163,7 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, >>> iris = load_iris() >>> clf = clf.fit(iris.data, iris.target) - >>> tree.plot_tree(clf) # doctest: +ELLIPSIS + >>> tree.plot_tree(clf) # doctest: +SKIP [Text(251.5,345.217,'X[3] <= 0.8... """ From 88cc0602a482a52ac6335dd67fb68b194a87afaf Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 28 Jun 2018 15:17:32 -0400 Subject: [PATCH 37/61] trying to debug circle failure --- sklearn/tree/export.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 29ea70abaf6c8..9a6db3e8c297a 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -578,6 +578,7 @@ def export(self, decision_tree, ax=None): # update sizes of all bboxes renderer = ax.figure.canvas.get_renderer() + print(renderer) # fixme worst debugging ever for ann in anns: ann.update_bbox_position_size(renderer) From 9eb0549cc44efccf7ddc957b4c524b2f97aa5b0e Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 29 Jun 2018 12:07:32 -0400 Subject: [PATCH 38/61] trying to show full traceback --- doc/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/Makefile b/doc/Makefile index 557eeaa188d2f..a6e7ab4e1043b 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -13,7 +13,7 @@ endif # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS)\ +ALLSPHINXOPTS = -T -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS)\ $(EXAMPLES_PATTERN_OPTS) . From 61dd70a19c42d274aff86c7e4f0ee9e0d71108e3 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 29 Jun 2018 12:52:42 -0400 Subject: [PATCH 39/61] more print debugging --- sklearn/tree/export.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 9a6db3e8c297a..2cdf782befd81 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -578,16 +578,22 @@ def export(self, decision_tree, ax=None): # update sizes of all bboxes renderer = ax.figure.canvas.get_renderer() + print("renderer") print(renderer) # fixme worst debugging ever + import warnings for ann in anns: + print(ann) + warnings.warn(ann) ann.update_bbox_position_size(renderer) if self.fontsize is None: # get figure to data transform # adjust fontsize to avoid overlap # get max box width and height + warnings.warn("dohh") + warnings.warn(renderer) extents = [ann.get_bbox_patch().get_window_extent() - for ann in anns] + for ann in anns] max_width = max([extent.width for extent in extents]) max_height = max([extent.height for extent in extents]) # width should be around scale_x in axis coordinates From 5d847438ad095c4108ba438ce614470b232af47f Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 29 Jun 2018 13:12:19 -0400 Subject: [PATCH 40/61] remove debugging crud --- sklearn/tree/export.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 2cdf782befd81..40388ef731931 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -578,22 +578,16 @@ def export(self, decision_tree, ax=None): # update sizes of all bboxes renderer = ax.figure.canvas.get_renderer() - print("renderer") - print(renderer) # fixme worst debugging ever - import warnings + for ann in anns: - print(ann) - warnings.warn(ann) ann.update_bbox_position_size(renderer) if self.fontsize is None: # get figure to data transform # adjust fontsize to avoid overlap # get max box width and height - warnings.warn("dohh") - warnings.warn(renderer) extents = [ann.get_bbox_patch().get_window_extent() - for ann in anns] + for ann in anns] max_width = max([extent.width for extent in extents]) max_height = max([extent.height for extent in extents]) # width should be around scale_x in axis coordinates From cf4a620f304f20760aac33135a337b7f88be4f31 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 29 Jun 2018 15:01:06 -0400 Subject: [PATCH 41/61] hack around matplotlib <1.5 issues --- sklearn/tree/export.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 40388ef731931..2bba28a235a60 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -10,6 +10,7 @@ # Trevor Stephens # Li Li # License: BSD 3 clause +import warnings from numbers import Integral @@ -586,15 +587,21 @@ def export(self, decision_tree, ax=None): # get figure to data transform # adjust fontsize to avoid overlap # get max box width and height - extents = [ann.get_bbox_patch().get_window_extent() - for ann in anns] - max_width = max([extent.width for extent in extents]) - max_height = max([extent.height for extent in extents]) - # width should be around scale_x in axis coordinates - size = anns[0].get_fontsize() * min(scale_x / max_width, - scale_y / max_height) - for ann in anns: - ann.set_fontsize(size) + try: + extents = [ann.get_bbox_patch().get_window_extent() + for ann in anns] + max_width = max([extent.width for extent in extents]) + max_height = max([extent.height for extent in extents]) + # width should be around scale_x in axis coordinates + size = anns[0].get_fontsize() * min(scale_x / max_width, + scale_y / max_height) + for ann in anns: + ann.set_fontsize(size) + except AttributeError: + # matplotlib < 1.5 + warnings.warn("Automatic scaling of tree plots requires " + "matplotlib 1.5 or higher. Please specify fontsize.") + return anns def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0): From db612491dba940703ac19de4390ec0e0f80dd52b Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 29 Jun 2018 15:29:45 -0400 Subject: [PATCH 42/61] copy bbox args because old matplotlib is weird. --- sklearn/tree/export.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 2bba28a235a60..6e09ccad0ee1a 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -605,7 +605,8 @@ def export(self, decision_tree, ax=None): return anns def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0): - kwargs = dict(bbox=self.bbox_args, ha='center', va='center', + # need to copy bbox args because matplotib <1.5 modifies them + kwargs = dict(bbox=self.bbox_args.copy(), ha='center', va='center', zorder=100 - 10 * depth, xycoords='axes pixels') if self.fontsize is not None: From 5beedf2de6963fadb5830e2fd3ae650fbcb89b9e Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 29 Jun 2018 15:37:25 -0400 Subject: [PATCH 43/61] pep8 fixes --- sklearn/tree/export.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 6e09ccad0ee1a..5198ab8188a15 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -589,7 +589,7 @@ def export(self, decision_tree, ax=None): # get max box width and height try: extents = [ann.get_bbox_patch().get_window_extent() - for ann in anns] + for ann in anns] max_width = max([extent.width for extent in extents]) max_height = max([extent.height for extent in extents]) # width should be around scale_x in axis coordinates @@ -600,7 +600,8 @@ def export(self, decision_tree, ax=None): except AttributeError: # matplotlib < 1.5 warnings.warn("Automatic scaling of tree plots requires " - "matplotlib 1.5 or higher. Please specify fontsize.") + "matplotlib 1.5 or higher. Please specify " + "fontsize.") return anns From 64d47ac2374fad39f273cc85c5b8cb92647b328a Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 29 Jun 2018 16:23:42 -0400 Subject: [PATCH 44/61] add explicit boxstyle --- sklearn/tree/export.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 5198ab8188a15..0ad5ec48f26ea 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -533,6 +533,9 @@ def __init__(self, max_depth=None, feature_names=None, self.bbox_args = dict(fc='w') if self.rounded: self.bbox_args['boxstyle'] = "round" + else: + # matplotlib <1.5 requires explicit boxstyle + self.bbox_args['boxstyle'] = "square" self.arrow_args = dict(arrowstyle="<-") From cbfce1c0ef1e8ea9b616f9124d4fe6005df6ad35 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 29 Jun 2018 16:25:15 -0400 Subject: [PATCH 45/61] more pep8 --- sklearn/tree/export.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 0ad5ec48f26ea..70f955b08f9d5 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -153,7 +153,8 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, Returns ------- annotations : list of artists - List containing the artists for the annotation boxes making up the tree. + List containing the artists for the annotation boxes making up the + tree. Examples -------- From f92aca1d63214e5d609ece7d800d26ef9a1b7676 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 29 Jun 2018 17:43:26 -0400 Subject: [PATCH 46/61] even more pep8 --- sklearn/tree/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 70f955b08f9d5..8a894ca79c397 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -149,7 +149,7 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, fontsize : int, optional (default=None) Size of text font. If None, determined automatically to fit figure. - + Returns ------- annotations : list of artists From d6cc6ada14040d828e285a4adcc2c3c9bc039f64 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Sat, 14 Jul 2018 11:35:09 -0500 Subject: [PATCH 47/61] add comment about matplotlib version requirement --- sklearn/tree/export.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 8a894ca79c397..1d062ba986a5d 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -83,6 +83,7 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, The sample counts that are shown are weighted with any sample_weights that might be present. + This function requires matplotlib, and works best with matplotlib >= 1.5. Read more in the :ref:`User Guide `. From daae2e2e783efa1ac752b36306fe3f35f91e6428 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 20 Aug 2018 15:51:25 -0400 Subject: [PATCH 48/61] remove redundant file --- tree_plotting.py | 561 ----------------------------------------------- 1 file changed, 561 deletions(-) delete mode 100644 tree_plotting.py diff --git a/tree_plotting.py b/tree_plotting.py deleted file mode 100644 index 19d86a53c264b..0000000000000 --- a/tree_plotting.py +++ /dev/null @@ -1,561 +0,0 @@ -import numpy as np -from numbers import Integral - -from sklearn.externals import six -from sklearn.tree.export import _color_brew, _criterion, _tree - - -def plot_tree(decision_tree, max_depth=None, feature_names=None, - class_names=None, label='all', filled=False, - leaves_parallel=False, impurity=True, node_ids=False, - proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3, ax=None, fontsize=None): - """Plot a decision tree. - - The sample counts that are shown are weighted with any sample_weights that - might be present. - - Parameters - ---------- - decision_tree : decision tree classifier - The decision tree to be exported to GraphViz. - - max_depth : int, optional (default=None) - The maximum depth of the representation. If None, the tree is fully - generated. - - feature_names : list of strings, optional (default=None) - Names of each of the features. - - class_names : list of strings, bool or None, optional (default=None) - Names of each of the target classes in ascending numerical order. - Only relevant for classification and not supported for multi-output. - If ``True``, shows a symbolic representation of the class name. - - label : {'all', 'root', 'none'}, optional (default='all') - Whether to show informative labels for impurity, etc. - Options include 'all' to show at every node, 'root' to show only at - the top root node, or 'none' to not show at any node. - - filled : bool, optional (default=False) - When set to ``True``, paint nodes to indicate majority class for - classification, extremity of values for regression, or purity of node - for multi-output. - - leaves_parallel : bool, optional (default=False) - When set to ``True``, draw all leaf nodes at the bottom of the tree. - - impurity : bool, optional (default=True) - When set to ``True``, show the impurity at each node. - - node_ids : bool, optional (default=False) - When set to ``True``, show the ID number on each node. - - proportion : bool, optional (default=False) - When set to ``True``, change the display of 'values' and/or 'samples' - to be proportions and percentages respectively. - - rotate : bool, optional (default=False) - When set to ``True``, orient tree left to right rather than top-down. - - rounded : bool, optional (default=False) - When set to ``True``, draw node boxes with rounded corners and use - Helvetica fonts instead of Times-Roman. - - special_characters : bool, optional (default=False) - When set to ``False``, ignore special characters for PostScript - compatibility. - - precision : int, optional (default=3) - Number of digits of precision for floating point in the values of - impurity, threshold and value attributes of each node. - - ax : matplotlib axis, optional (default=None) - Axes to plot to. If None, use current axis. - - Examples - -------- - >>> from sklearn.datasets import load_iris - - >>> clf = tree.DecisionTreeClassifier() - >>> iris = load_iris() - - >>> clf = clf.fit(iris.data, iris.target) - >>> plot_tree(clf) # doctest: +SKIP - - """ - exporter = _MPLTreeExporter( - max_depth=max_depth, feature_names=feature_names, - class_names=class_names, label=label, filled=filled, - leaves_parallel=leaves_parallel, impurity=impurity, node_ids=node_ids, - proportion=proportion, rotate=rotate, rounded=rounded, - special_characters=special_characters, precision=precision, - fontsize=fontsize) - exporter.export(decision_tree, ax=ax) - - -class _BaseTreeExporter(object): - def get_color(self, value): - # Find the appropriate color & intensity for a node - if self.colors['bounds'] is None: - # Classification tree - color = list(self.colors['rgb'][np.argmax(value)]) - sorted_values = sorted(value, reverse=True) - if len(sorted_values) == 1: - alpha = 0 - else: - alpha = ((sorted_values[0] - sorted_values[1]) - / (1 - sorted_values[1])) - else: - # Regression tree or multi-output - color = list(self.colors['rgb'][0]) - alpha = ((value - self.colors['bounds'][0]) / - (self.colors['bounds'][1] - self.colors['bounds'][0])) - # unpack numpy scalars - alpha = float(alpha) - # compute the color as alpha against white - color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color] - # Return html color code in #RRGGBB format - hex_codes = [str(i) for i in range(10)] - hex_codes.extend(['a', 'b', 'c', 'd', 'e', 'f']) - color = [hex_codes[c // 16] + hex_codes[c % 16] for c in color] - - return '#' + ''.join(color) - - def get_fill_color(self, tree, node_id): - # Fetch appropriate color for node - if 'rgb' not in self.colors: - # Initialize colors and bounds if required - self.colors['rgb'] = _color_brew(tree.n_classes[0]) - if tree.n_outputs != 1: - # Find max and min impurities for multi-output - self.colors['bounds'] = (np.min(-tree.impurity), - np.max(-tree.impurity)) - elif (tree.n_classes[0] == 1 and - len(np.unique(tree.value)) != 1): - # Find max and min values in leaf nodes for regression - self.colors['bounds'] = (np.min(tree.value), - np.max(tree.value)) - if tree.n_outputs == 1: - node_val = (tree.value[node_id][0, :] / - tree.weighted_n_node_samples[node_id]) - if tree.n_classes[0] == 1: - # Regression - node_val = tree.value[node_id][0, :] - else: - # If multi-output color node by impurity - node_val = -tree.impurity[node_id] - return self.get_color(node_val) - - def node_to_str(self, tree, node_id, criterion): - # Generate the node content string - if tree.n_outputs == 1: - value = tree.value[node_id][0, :] - else: - value = tree.value[node_id] - - # Should labels be shown? - labels = (self.label == 'root' and node_id == 0) or self.label == 'all' - - characters = self.characters - node_string = characters[-1] - - # Write node ID - if self.node_ids: - if labels: - node_string += 'node ' - node_string += characters[0] + str(node_id) + characters[4] - - # Write decision criteria - if tree.children_left[node_id] != _tree.TREE_LEAF: - # Always write node decision criteria, except for leaves - if self.feature_names is not None: - feature = self.feature_names[tree.feature[node_id]] - else: - feature = "X%s%s%s" % (characters[1], - tree.feature[node_id], - characters[2]) - node_string += '%s %s %s%s' % (feature, - characters[3], - round(tree.threshold[node_id], - self.precision), - characters[4]) - - # Write impurity - if self.impurity: - if isinstance(criterion, _criterion.FriedmanMSE): - criterion = "friedman_mse" - elif not isinstance(criterion, six.string_types): - criterion = "impurity" - if labels: - node_string += '%s = ' % criterion - node_string += (str(round(tree.impurity[node_id], self.precision)) - + characters[4]) - - # Write node sample count - if labels: - node_string += 'samples = ' - if self.proportion: - percent = (100. * tree.n_node_samples[node_id] / - float(tree.n_node_samples[0])) - node_string += (str(round(percent, 1)) + '%' + - characters[4]) - else: - node_string += (str(tree.n_node_samples[node_id]) + - characters[4]) - - # Write node class distribution / regression value - if self.proportion and tree.n_classes[0] != 1: - # For classification this will show the proportion of samples - value = value / tree.weighted_n_node_samples[node_id] - if labels: - node_string += 'value = ' - if tree.n_classes[0] == 1: - # Regression - value_text = np.around(value, self.precision) - elif self.proportion: - # Classification - value_text = np.around(value, self.precision) - elif np.all(np.equal(np.mod(value, 1), 0)): - # Classification without floating-point weights - value_text = value.astype(int) - else: - # Classification with floating-point weights - value_text = np.around(value, self.precision) - # Strip whitespace - value_text = str(value_text.astype('S32')).replace("b'", "'") - value_text = value_text.replace("' '", ", ").replace("'", "") - if tree.n_classes[0] == 1 and tree.n_outputs == 1: - value_text = value_text.replace("[", "").replace("]", "") - value_text = value_text.replace("\n ", characters[4]) - node_string += value_text + characters[4] - - # Write node majority class - if (self.class_names is not None and - tree.n_classes[0] != 1 and - tree.n_outputs == 1): - # Only done for single-output classification trees - if labels: - node_string += 'class = ' - if self.class_names is not True: - class_name = self.class_names[np.argmax(value)] - else: - class_name = "y%s%s%s" % (characters[1], - np.argmax(value), - characters[2]) - node_string += class_name - - # Clean up any trailing newlines - if node_string.endswith(characters[4]): - node_string = node_string[:-len(characters[4])] - - return node_string + characters[5] - - -class _MPLTreeExporter(_BaseTreeExporter): - def __init__(self, max_depth=None, feature_names=None, - class_names=None, label='all', filled=False, - leaves_parallel=False, impurity=True, node_ids=False, - proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3, fontsize=None): - self.max_depth = max_depth - self.feature_names = feature_names - self.class_names = class_names - self.label = label - self.filled = filled - self.leaves_parallel = leaves_parallel - self.impurity = impurity - self.node_ids = node_ids - self.proportion = proportion - self.rotate = rotate - self.rounded = rounded - self.special_characters = special_characters - self.precision = precision - self.fontsize = fontsize - self._scaley = 10 - - # validate - if isinstance(precision, Integral): - if precision < 0: - raise ValueError("'precision' should be greater or equal to 0." - " Got {} instead.".format(precision)) - else: - raise ValueError("'precision' should be an integer. Got {}" - " instead.".format(type(precision))) - - # The depth of each node for plotting with 'leaf' option - self.ranks = {'leaves': []} - # The colors to render each node with - self.colors = {'bounds': None} - - self.characters = ['#', '[', ']', '<=', '\n', '', ''] - - self.bbox_args = dict(fc='w') - if self.rounded: - self.bbox_args['boxstyle'] = "round" - self.arrow_args = dict(arrowstyle="<-") - - def _make_tree(self, node_id, et): - # traverses _tree.Tree recursively, builds intermediate - # "_reingold_tilford.Tree" object - name = self.node_to_str(et, node_id, criterion='entropy') - if (et.children_left[node_id] != et.children_right[node_id]): - children = [self._make_tree(et.children_left[node_id], et), - self._make_tree(et.children_right[node_id], et)] - else: - return Tree(name, node_id) - return Tree(name, node_id, *children) - - def export(self, decision_tree, ax=None): - import matplotlib.pyplot as plt - from matplotlib.text import Annotation - - if ax is None: - ax = plt.gca() - ax.set_axis_off() - my_tree = self._make_tree(0, decision_tree.tree_) - dt = buchheim(my_tree) - self._scalex = 1 - self.recurse(dt, decision_tree.tree_, ax) - - anns = [ann for ann in ax.get_children() - if isinstance(ann, Annotation)] - - # get all the annotated points - xys = [ann.xyann for ann in anns] - - mins = np.min(xys, axis=0) - maxs = np.max(xys, axis=0) - - ax.set_xlim(mins[0], maxs[0]) - ax.set_ylim(maxs[1], mins[1]) - - if self.fontsize is None: - # get figure to data transform - inv = ax.transData.inverted() - renderer = ax.figure.canvas.get_renderer() - # update sizes of all bboxes - for ann in anns: - ann.update_bbox_position_size(renderer) - # get max box width - widths = [inv.get_matrix()[0, 0] - * ann.get_bbox_patch().get_window_extent().width - for ann in anns] - # get minimum max size to not be too big. - max_width = max(max(widths), 1) - # adjust fontsize to avoid overlap - # width should be around 1 in data coordinates - size = anns[0].get_fontsize() / max_width - for ann in anns: - ann.set_fontsize(size) - - def recurse(self, node, tree, ax, depth=0): - kwargs = dict(bbox=self.bbox_args, ha='center', va='center', - zorder=100 - 10 * depth) - - if self.fontsize is not None: - kwargs['fontsize'] = self.fontsize - - xy = (node.x * self._scalex, node.y * self._scaley) - - if self.max_depth is None or depth <= self.max_depth: - if self.filled: - kwargs['bbox']['fc'] = self.get_fill_color(tree, - node.tree.node_id) - if node.parent is None: - # root - ax.annotate(node.tree.node, xy, **kwargs) - else: - xy_parent = (node.parent.x * self._scalex, - node.parent.y * self._scaley) - kwargs["arrowprops"] = self.arrow_args - ax.annotate(node.tree.node, xy_parent, xy, **kwargs) - for child in node.children: - self.recurse(child, tree, ax, depth=depth + 1) - - else: - xy_parent = (node.parent.x * self._scalex, node.parent.y * - self._scaley) - kwargs["arrowprops"] = self.arrow_args - kwargs['bbox']['fc'] = 'grey' - ax.annotate("\n (...) \n", xy_parent, xy, **kwargs) - - -class DrawTree(object): - def __init__(self, tree, parent=None, depth=0, number=1): - self.x = -1. - self.y = depth - self.tree = tree - self.children = [DrawTree(c, self, depth + 1, i + 1) - for i, c - in enumerate(tree.children)] - self.parent = parent - self.thread = None - self.mod = 0 - self.ancestor = self - self.change = self.shift = 0 - self._lmost_sibling = None - # this is the number of the node in its group of siblings 1..n - self.number = number - - def left(self): - return self.thread or len(self.children) and self.children[0] - - def right(self): - return self.thread or len(self.children) and self.children[-1] - - def lbrother(self): - n = None - if self.parent: - for node in self.parent.children: - if node == self: - return n - else: - n = node - return n - - def get_lmost_sibling(self): - if not self._lmost_sibling and self.parent and self != \ - self.parent.children[0]: - self._lmost_sibling = self.parent.children[0] - return self._lmost_sibling - lmost_sibling = property(get_lmost_sibling) - - def __str__(self): - return "%s: x=%s mod=%s" % (self.tree, self.x, self.mod) - - def __repr__(self): - return self.__str__() - - -def buchheim(tree): - dt = firstwalk(DrawTree(tree)) - min = second_walk(dt) - if min < 0: - third_walk(dt, -min) - return dt - - -def third_walk(tree, n): - tree.x += n - for c in tree.children: - third_walk(c, n) - - -def firstwalk(v, distance=1.): - if len(v.children) == 0: - if v.lmost_sibling: - v.x = v.lbrother().x + distance - else: - v.x = 0. - else: - default_ancestor = v.children[0] - for w in v.children: - firstwalk(w) - default_ancestor = apportion(w, default_ancestor, distance) - # print("finished v =", v.tree, "children") - execute_shifts(v) - - midpoint = (v.children[0].x + v.children[-1].x) / 2 - - w = v.lbrother() - if w: - v.x = w.x + distance - v.mod = v.x - midpoint - else: - v.x = midpoint - return v - - -def apportion(v, default_ancestor, distance): - w = v.lbrother() - if w is not None: - # in buchheim notation: - # i == inner; o == outer; r == right; l == left; r = +; l = - - vir = vor = v - vil = w - vol = v.lmost_sibling - sir = sor = v.mod - sil = vil.mod - sol = vol.mod - while vil.right() and vir.left(): - vil = vil.right() - vir = vir.left() - vol = vol.left() - vor = vor.right() - vor.ancestor = v - shift = (vil.x + sil) - (vir.x + sir) + distance - if shift > 0: - move_subtree(ancestor(vil, v, default_ancestor), v, shift) - sir = sir + shift - sor = sor + shift - sil += vil.mod - sir += vir.mod - sol += vol.mod - sor += vor.mod - if vil.right() and not vor.right(): - vor.thread = vil.right() - vor.mod += sil - sor - else: - if vir.left() and not vol.left(): - vol.thread = vir.left() - vol.mod += sir - sol - default_ancestor = v - return default_ancestor - - -def move_subtree(wl, wr, shift): - subtrees = wr.number - wl.number - # print(wl.tree, "is conflicted with", wr.tree, 'moving', subtrees, - # 'shift', shift) - # print wl, wr, wr.number, wl.number, shift, subtrees, shift/subtrees - wr.change -= shift / subtrees - wr.shift += shift - wl.change += shift / subtrees - wr.x += shift - wr.mod += shift - - -def execute_shifts(v): - shift = change = 0 - for w in v.children[::-1]: - # print("shift:", w, shift, w.change) - w.x += shift - w.mod += shift - change += w.change - shift += w.shift + change - - -def ancestor(vil, v, default_ancestor): - # the relevant text is at the bottom of page 7 of - # "Improving Walker's Algorithm to Run in Linear Time" by Buchheim et al, - # (2002) - # http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.16.8757&rep=rep1&type=pdf - if vil.ancestor in v.parent.children: - return vil.ancestor - else: - return default_ancestor - - -def second_walk(v, m=0, depth=0, min=None): - v.x += m - v.y = depth - - if min is None or v.x < min: - min = v.x - - for w in v.children: - min = second_walk(w, m + v.mod, depth + 1, min) - - return min - - -class Tree(object): - def __init__(self, node="", node_id=-1, *children): - self.node = node - self.width = len(node) - self.node_id = node_id - if children: - self.children = children - else: - self.children = [] From 7d76ca9980700a72233ca58fcb700e8b607ce05b Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 28 Sep 2018 16:47:32 -0400 Subject: [PATCH 49/61] add whatsnew entry that the merge lost --- doc/whats_new/v0.21.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index 03440502aecb2..e12884bbb38f0 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -48,6 +48,15 @@ Support for Python 3.4 and below has been officially dropped. to set and that scales better, by :user:`Shane ` and :user:`Adrin Jalali `. + +:mod:`sklearn.tree` +...................... + +- Decision Trees can now be plotted with matplotlib using + :func:`tree.export.plot_tree` without relying on the ``dot`` library, + removing a hard-to-install dependency. + :issue:`8508` by `Andreas Müller`_. + Multiple modules ................ From 1937f3de445e72bb1c7b2485b6833ec6797205b4 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 1 Oct 2018 16:30:22 -0400 Subject: [PATCH 50/61] fix merge issue --- doc/whats_new/v0.20.rst | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 9aeddc893136e..e35990fe40006 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -254,21 +254,10 @@ Support for Python 3.3 has been officially dropped. and will be used instead of mldata as it provides better service availability. :issue:`9908` by `Andreas Müller`_ and :user:`Jan N. van Rijn `. -<<<<<<< HEAD -- Decision Trees can now be plotted with matplotlib using - :func:`tree.export.plot_tree` without relying on the ``dot`` library, - removing a hard-to-install dependency. - :issue:`8508` by `Andreas Müller`_. - -- An environment variable to use the site joblib instead of the vendored - one was added (:ref:`environment_variable`). - :issue:`11166`by `Gael Varoquaux`_ -======= - |Feature| In :func:`datasets.make_blobs`, one can now pass a list to the ``n_samples`` parameter to indicate the number of samples to generate per cluster. :issue:`8617` by :user:`Maskani Filali Mohamed ` and :user:`Konstantinos Katrioplas `. ->>>>>>> master - |Feature| Add ``filename`` attribute to :mod:`datasets` that have a CSV file. :issue:`9101` by :user:`alex-33 ` From db464a392b3fe7f991b19447ba5c95e427172d1c Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 1 Oct 2018 16:31:36 -0400 Subject: [PATCH 51/61] more merge issues --- doc/whats_new/v0.21.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index e9f18842120d3..735a497e8a4ac 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -48,7 +48,7 @@ Support for Python 3.4 and below has been officially dropped. to set and that scales better, by :user:`Shane ` and :user:`Adrin Jalali `. - - |Fix| Fixed a bug in :class:`cluster.DBSCAN` with precomputed sparse neighbors +- |Fix| Fixed a bug in :class:`cluster.DBSCAN` with precomputed sparse neighbors graph, which would add explicitly zeros on the diagonal even when already present. :issue:`12105` by `Tom Dupre la Tour`_. From 042865a666589889e9a510f006c21157054536be Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 1 Oct 2018 16:32:36 -0400 Subject: [PATCH 52/61] whitespace ... --- doc/whats_new/v0.21.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index 735a497e8a4ac..b605ecb1749b5 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -47,7 +47,7 @@ Support for Python 3.4 and below has been officially dropped. algoritm related to :class:`cluster.DBSCAN`, that has hyperparameters easier to set and that scales better, by :user:`Shane ` and :user:`Adrin Jalali `. - + - |Fix| Fixed a bug in :class:`cluster.DBSCAN` with precomputed sparse neighbors graph, which would add explicitly zeros on the diagonal even when already present. :issue:`12105` by `Tom Dupre la Tour`_. From 1a1675099f2488db2fcb029c0daada3dc843243e Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 3 Oct 2018 14:09:07 -0400 Subject: [PATCH 53/61] remove doctest skip to see what's happening --- sklearn/tree/export.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 1d062ba986a5d..f0b242e28b041 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -85,6 +85,10 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, might be present. This function requires matplotlib, and works best with matplotlib >= 1.5. + The visuaization is fit automatically to the size of the axis. + Use the ``figsize`` or ``dpi`` arguments of ``plt.figure`` to control + the size of the rendering. + Read more in the :ref:`User Guide `. .. versionadded:: 0.20 @@ -166,7 +170,7 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, >>> iris = load_iris() >>> clf = clf.fit(iris.data, iris.target) - >>> tree.plot_tree(clf) # doctest: +SKIP + >>> tree.plot_tree(clf) [Text(251.5,345.217,'X[3] <= 0.8... """ From 6f2d5975b130fe96137f435a3c70b8d43884f8dc Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 3 Oct 2018 15:18:00 -0400 Subject: [PATCH 54/61] added some simple invariance tests buchheim function --- sklearn/tree/tests/test_reingold.py | 54 +++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 sklearn/tree/tests/test_reingold.py diff --git a/sklearn/tree/tests/test_reingold.py b/sklearn/tree/tests/test_reingold.py new file mode 100644 index 0000000000000..4cb27ce6effb9 --- /dev/null +++ b/sklearn/tree/tests/test_reingold.py @@ -0,0 +1,54 @@ +import numpy as np +import pytest +from sklearn.tree._reingold_tilford import buchheim, Tree + +simple_tree = Tree("", 0, + Tree("", 1), + Tree("", 2)) + +bigger_tree = Tree("", 0, + Tree("", 1, + Tree("", 3), + Tree("", 4, + Tree("", 7), + Tree("", 8) + ), + ), + Tree("", 2, + Tree("", 5), + Tree("", 6) + ) + ) + + +@pytest.mark.parametrize("tree, n_nodes", [(simple_tree, 3), (bigger_tree, 9)]) +def test_buchheim(tree, n_nodes): + def walk_tree(draw_tree): + res = [(draw_tree.x, draw_tree.y)] + for child in draw_tree.children: + # parents higher than children: + assert child.y == draw_tree.y + 1 + res.extend(walk_tree(child)) + if len(draw_tree.children): + # these trees are always binary + # parents are centered above children + assert draw_tree.x == (draw_tree.children[0].x + + draw_tree.children[1].x) / 2 + return res + + layout = buchheim(tree) + coordinates = walk_tree(layout) + assert len(coordinates) == n_nodes + # test that x values are unique per depth / level + # we could also do it quicker using defaultdicts.. + depth = 0 + while True: + x_at_this_depth = [] + for node in coordinates: + if coordinates[1] == depth: + x_at_this_depth.append(coordinates[0]) + if not x_at_this_depth: + # reached all leafs + break + assert len(np.unique(x_at_this_depth)) == len(x_at_this_depth) + depth += 1 From 6803f96129e3271a7d486d4274f0a4ed4bd3760d Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 3 Oct 2018 15:24:32 -0400 Subject: [PATCH 55/61] refactor ___init__ into superclass --- sklearn/tree/export.py | 62 ++++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index f0b242e28b041..8b955ed229d09 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -185,6 +185,26 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, class _BaseTreeExporter(object): + def __init__(self, max_depth=None, feature_names=None, + class_names=None, label='all', filled=False, + leaves_parallel=False, impurity=True, node_ids=False, + proportion=False, rotate=False, rounded=False, + special_characters=False, precision=3, fontsize=None): + self.max_depth = max_depth + self.feature_names = feature_names + self.class_names = class_names + self.label = label + self.filled = filled + self.leaves_parallel = leaves_parallel + self.impurity = impurity + self.node_ids = node_ids + self.proportion = proportion + self.rotate = rotate + self.rounded = rounded + self.special_characters = special_characters + self.precision = precision + self.fontsize = fontsize + def get_color(self, value): # Find the appropriate color & intensity for a node if self.colors['bounds'] is None: @@ -344,20 +364,15 @@ def __init__(self, out_file=SENTINEL, max_depth=None, filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, special_characters=False, precision=3): + + super(_DOTTreeExporter, self).__init__( + max_depth=max_depth, feature_names=feature_names, + class_names=class_names, label=label, filled=filled, + leaves_parallel=leaves_parallel, impurity=impurity, + node_ids=node_ids, proportion=proportion, rotate=rotate, + rounded=rounded, special_characters=special_characters, + precision=precision) self.out_file = out_file - self.max_depth = max_depth - self.feature_names = feature_names - self.class_names = class_names - self.label = label - self.filled = filled - self.leaves_parallel = leaves_parallel - self.impurity = impurity - self.node_ids = node_ids - self.proportion = proportion - self.rotate = rotate - self.rounded = rounded - self.special_characters = special_characters - self.precision = precision # PostScript compatibility for special characters if special_characters: @@ -505,19 +520,14 @@ def __init__(self, max_depth=None, feature_names=None, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, special_characters=False, precision=3, fontsize=None): - self.max_depth = max_depth - self.feature_names = feature_names - self.class_names = class_names - self.label = label - self.filled = filled - self.leaves_parallel = leaves_parallel - self.impurity = impurity - self.node_ids = node_ids - self.proportion = proportion - self.rotate = rotate - self.rounded = rounded - self.special_characters = special_characters - self.precision = precision + + super(_MPLTreeExporter, self).__init__( + max_depth=max_depth, feature_names=feature_names, + class_names=class_names, label=label, filled=filled, + leaves_parallel=leaves_parallel, impurity=impurity, + node_ids=node_ids, proportion=proportion, rotate=rotate, + rounded=rounded, special_characters=special_characters, + precision=precision) self.fontsize = fontsize # validate From 7b623169a5fc879197cddba44b255110b1755aa6 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 3 Oct 2018 15:37:23 -0400 Subject: [PATCH 56/61] added some tests of plot_tree --- sklearn/tree/tests/test_export.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index 430483e6d41cf..2471914fa44ce 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -1,6 +1,7 @@ """ Testing for export functions of decision trees (sklearn.tree.export). """ +import pytest from re import finditer, search @@ -9,7 +10,7 @@ from sklearn.base import is_classifier from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.ensemble import GradientBoostingClassifier -from sklearn.tree import export_graphviz +from sklearn.tree import export_graphviz, plot_tree from sklearn.externals.six import StringIO from sklearn.utils.testing import (assert_in, assert_equal, assert_raises, assert_less_equal, assert_raises_regex, @@ -308,3 +309,23 @@ def test_precision(): for finding in finditer(r"<= \d+\.\d+", dot_data): assert_equal(len(search(r"\.\d+", finding.group()).group()), precision + 1) + + +def test_plot_tree(): + # mostly smoke tests + pytest.importorskip("matplotlib.pyplot") + # Check correctness of export_graphviz + clf = DecisionTreeClassifier(max_depth=3, + min_samples_split=2, + criterion="gini", + random_state=2) + clf.fit(X, y) + + # Test export code + feature_names = ['first feat', 'sepal_width'] + nodes = plot_tree(clf, feature_names=feature_names) + assert len(nodes) == 3 + assert nodes[0].get_text() == ("first feat <= 0.0\nentropy = 0.5\n" + "samples = 6\nvalue = [3, 3]") + assert nodes[1].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [3, 0]" + assert nodes[2].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [0, 3]" From 817b2bbcdf0d29cd9e0f00708e1f375ba77cfa3b Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 3 Oct 2018 15:51:27 -0400 Subject: [PATCH 57/61] put skip back in, fix typo, fix versionadded number --- sklearn/tree/export.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 8b955ed229d09..5c1351a424800 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -85,13 +85,13 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, might be present. This function requires matplotlib, and works best with matplotlib >= 1.5. - The visuaization is fit automatically to the size of the axis. + The visualization is fit automatically to the size of the axis. Use the ``figsize`` or ``dpi`` arguments of ``plt.figure`` to control the size of the rendering. Read more in the :ref:`User Guide `. - .. versionadded:: 0.20 + .. versionadded:: 0.21 Parameters ---------- @@ -170,7 +170,7 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, >>> iris = load_iris() >>> clf = clf.fit(iris.data, iris.target) - >>> tree.plot_tree(clf) + >>> tree.plot_tree(clf) # doctest: +SKIP [Text(251.5,345.217,'X[3] <= 0.8... """ From 55c7d3607d6fad1a736912885698dc44b12bc57e Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 3 Oct 2018 15:58:10 -0400 Subject: [PATCH 58/61] remove unused parameters special_characters and parallel_leaves from mpl plotting --- sklearn/tree/export.py | 38 ++++++++++++++------------------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 5c1351a424800..fe127d77302b6 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -76,9 +76,9 @@ def __repr__(self): def plot_tree(decision_tree, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, - leaves_parallel=False, impurity=True, node_ids=False, + impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3, ax=None, fontsize=None): + precision=3, ax=None, fontsize=None): """Plot a decision tree. The sample counts that are shown are weighted with any sample_weights that @@ -120,9 +120,6 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, classification, extremity of values for regression, or purity of node for multi-output. - leaves_parallel : bool, optional (default=False) - When set to ``True``, draw all leaf nodes at the bottom of the tree. - impurity : bool, optional (default=True) When set to ``True``, show the impurity at each node. @@ -140,10 +137,6 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, When set to ``True``, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman. - special_characters : bool, optional (default=False) - When set to ``False``, ignore special characters for PostScript - compatibility. - precision : int, optional (default=3) Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node. @@ -177,31 +170,28 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None, exporter = _MPLTreeExporter( max_depth=max_depth, feature_names=feature_names, class_names=class_names, label=label, filled=filled, - leaves_parallel=leaves_parallel, impurity=impurity, node_ids=node_ids, + impurity=impurity, node_ids=node_ids, proportion=proportion, rotate=rotate, rounded=rounded, - special_characters=special_characters, precision=precision, - fontsize=fontsize) + precision=precision, fontsize=fontsize) return exporter.export(decision_tree, ax=ax) class _BaseTreeExporter(object): def __init__(self, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, - leaves_parallel=False, impurity=True, node_ids=False, + impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3, fontsize=None): + precision=3, fontsize=None): self.max_depth = max_depth self.feature_names = feature_names self.class_names = class_names self.label = label self.filled = filled - self.leaves_parallel = leaves_parallel self.impurity = impurity self.node_ids = node_ids self.proportion = proportion self.rotate = rotate self.rounded = rounded - self.special_characters = special_characters self.precision = precision self.fontsize = fontsize @@ -368,11 +358,13 @@ def __init__(self, out_file=SENTINEL, max_depth=None, super(_DOTTreeExporter, self).__init__( max_depth=max_depth, feature_names=feature_names, class_names=class_names, label=label, filled=filled, - leaves_parallel=leaves_parallel, impurity=impurity, + impurity=impurity, node_ids=node_ids, proportion=proportion, rotate=rotate, - rounded=rounded, special_characters=special_characters, + rounded=rounded, precision=precision) + self.leaves_parallel = leaves_parallel self.out_file = out_file + self.special_characters = special_characters # PostScript compatibility for special characters if special_characters: @@ -517,17 +509,15 @@ def recurse(self, tree, node_id, criterion, parent=None, depth=0): class _MPLTreeExporter(_BaseTreeExporter): def __init__(self, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, - leaves_parallel=False, impurity=True, node_ids=False, + impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3, fontsize=None): + precision=3, fontsize=None): super(_MPLTreeExporter, self).__init__( max_depth=max_depth, feature_names=feature_names, class_names=class_names, label=label, filled=filled, - leaves_parallel=leaves_parallel, impurity=impurity, - node_ids=node_ids, proportion=proportion, rotate=rotate, - rounded=rounded, special_characters=special_characters, - precision=precision) + impurity=impurity, node_ids=node_ids, proportion=proportion, + rotate=rotate, rounded=rounded, precision=precision) self.fontsize = fontsize # validate From 9554a87ac2fe8c05038eecc8cf1e9e2f975b643e Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 10 Oct 2018 15:41:50 -0400 Subject: [PATCH 59/61] rename tests to test_reingold_tilford --- sklearn/tree/tests/test_reingold_tilford.py | 54 +++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 sklearn/tree/tests/test_reingold_tilford.py diff --git a/sklearn/tree/tests/test_reingold_tilford.py b/sklearn/tree/tests/test_reingold_tilford.py new file mode 100644 index 0000000000000..4cb27ce6effb9 --- /dev/null +++ b/sklearn/tree/tests/test_reingold_tilford.py @@ -0,0 +1,54 @@ +import numpy as np +import pytest +from sklearn.tree._reingold_tilford import buchheim, Tree + +simple_tree = Tree("", 0, + Tree("", 1), + Tree("", 2)) + +bigger_tree = Tree("", 0, + Tree("", 1, + Tree("", 3), + Tree("", 4, + Tree("", 7), + Tree("", 8) + ), + ), + Tree("", 2, + Tree("", 5), + Tree("", 6) + ) + ) + + +@pytest.mark.parametrize("tree, n_nodes", [(simple_tree, 3), (bigger_tree, 9)]) +def test_buchheim(tree, n_nodes): + def walk_tree(draw_tree): + res = [(draw_tree.x, draw_tree.y)] + for child in draw_tree.children: + # parents higher than children: + assert child.y == draw_tree.y + 1 + res.extend(walk_tree(child)) + if len(draw_tree.children): + # these trees are always binary + # parents are centered above children + assert draw_tree.x == (draw_tree.children[0].x + + draw_tree.children[1].x) / 2 + return res + + layout = buchheim(tree) + coordinates = walk_tree(layout) + assert len(coordinates) == n_nodes + # test that x values are unique per depth / level + # we could also do it quicker using defaultdicts.. + depth = 0 + while True: + x_at_this_depth = [] + for node in coordinates: + if coordinates[1] == depth: + x_at_this_depth.append(coordinates[0]) + if not x_at_this_depth: + # reached all leafs + break + assert len(np.unique(x_at_this_depth)) == len(x_at_this_depth) + depth += 1 From 435d217358508f5c3edbf5a377653ed7d6070fd3 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Wed, 10 Oct 2018 15:45:10 -0400 Subject: [PATCH 60/61] added license header from pymag-trees repo --- sklearn/tree/_reingold_tilford.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sklearn/tree/_reingold_tilford.py b/sklearn/tree/_reingold_tilford.py index 0ca79d89887b3..d83969badb623 100644 --- a/sklearn/tree/_reingold_tilford.py +++ b/sklearn/tree/_reingold_tilford.py @@ -1,4 +1,20 @@ # taken from https://github.com/llimllib/pymag-trees/blob/master/buchheim.py +# with slight modifications + +# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE +# Version 2, December 2004 +# +# Copyright (C) 2004 Sam Hocevar +# +# Everyone is permitted to copy and distribute verbatim or modified +# copies of this license document, and changing it is allowed as long +# as the name is changed. +# +# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE +# TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + +# 0. You just DO WHAT THE FUCK YOU WANT TO. + import numpy as np From 98e8d5ab51e83ac466d67714313b984c2e85f51b Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Thu, 11 Oct 2018 15:41:48 -0400 Subject: [PATCH 61/61] remove duplicate test file. --- sklearn/tree/tests/test_reingold.py | 54 ----------------------------- 1 file changed, 54 deletions(-) delete mode 100644 sklearn/tree/tests/test_reingold.py diff --git a/sklearn/tree/tests/test_reingold.py b/sklearn/tree/tests/test_reingold.py deleted file mode 100644 index 4cb27ce6effb9..0000000000000 --- a/sklearn/tree/tests/test_reingold.py +++ /dev/null @@ -1,54 +0,0 @@ -import numpy as np -import pytest -from sklearn.tree._reingold_tilford import buchheim, Tree - -simple_tree = Tree("", 0, - Tree("", 1), - Tree("", 2)) - -bigger_tree = Tree("", 0, - Tree("", 1, - Tree("", 3), - Tree("", 4, - Tree("", 7), - Tree("", 8) - ), - ), - Tree("", 2, - Tree("", 5), - Tree("", 6) - ) - ) - - -@pytest.mark.parametrize("tree, n_nodes", [(simple_tree, 3), (bigger_tree, 9)]) -def test_buchheim(tree, n_nodes): - def walk_tree(draw_tree): - res = [(draw_tree.x, draw_tree.y)] - for child in draw_tree.children: - # parents higher than children: - assert child.y == draw_tree.y + 1 - res.extend(walk_tree(child)) - if len(draw_tree.children): - # these trees are always binary - # parents are centered above children - assert draw_tree.x == (draw_tree.children[0].x - + draw_tree.children[1].x) / 2 - return res - - layout = buchheim(tree) - coordinates = walk_tree(layout) - assert len(coordinates) == n_nodes - # test that x values are unique per depth / level - # we could also do it quicker using defaultdicts.. - depth = 0 - while True: - x_at_this_depth = [] - for node in coordinates: - if coordinates[1] == depth: - x_at_this_depth.append(coordinates[0]) - if not x_at_this_depth: - # reached all leafs - break - assert len(np.unique(x_at_this_depth)) == len(x_at_this_depth) - depth += 1