diff --git a/doc/Makefile b/doc/Makefile index fcb547d14e2b0..6629518fc556a 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -13,7 +13,7 @@ endif # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS)\ +ALLSPHINXOPTS = -T -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS)\ $(EXAMPLES_PATTERN_OPTS) . diff --git a/doc/modules/tree.rst b/doc/modules/tree.rst index 4c3f584b079ab..fe5bed4c0221f 100644 --- a/doc/modules/tree.rst +++ b/doc/modules/tree.rst @@ -124,9 +124,20 @@ Using the Iris dataset, we can construct a tree as follows:: >>> clf = tree.DecisionTreeClassifier() >>> clf = clf.fit(iris.data, iris.target) -Once trained, we can export the tree in `Graphviz +Once trained, you can plot the tree with the plot_tree function:: + + + >>> tree.plot_tree(clf.fit(iris.data, iris.target)) # doctest: +SKIP + +.. figure:: ../auto_examples/tree/images/sphx_glr_plot_iris_002.png + :target: ../auto_examples/tree/plot_iris.html + :scale: 75 + :align: center + +We can also export the tree in `Graphviz <https://www.graphviz.org/>`_ format using the :func:`export_graphviz` -exporter. If you use the `conda <https://conda.io>`_ package manager, the graphviz binaries +exporter. If you use the `conda <https://conda.io>`_ package manager, the graphviz binaries + and the python package can be installed with conda install python-graphviz diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index be200250a9ae5..c2312875b1c68 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -62,6 +62,11 @@ Support for Python 3.4 and below has been officially dropped. :mod:`sklearn.tree` ................... +- |Feature| Decision Trees can now be plotted with matplotlib using + :func:`tree.export.plot_tree` without relying on the ``dot`` library, + removing a hard-to-install dependency. + :issue:`8508` by `Andreas Müller`_.
+ - |Feature| ``get_n_leaves()`` and ``get_depth()`` have been added to :class:`tree.BaseDecisionTree` and consequently all estimators based on it, including :class:`tree.DecisionTreeClassifier`, diff --git a/examples/tree/plot_iris.py b/examples/tree/plot_iris.py index f299aab18d7d1..60328c4f90d4f 100644 --- a/examples/tree/plot_iris.py +++ b/examples/tree/plot_iris.py @@ -11,6 +11,8 @@ For each pair of iris features, the decision tree learns decision boundaries made of combinations of simple thresholding rules inferred from the training samples. + +We also show the tree structure of a model built on all of the features. """ print(__doc__) @@ -18,7 +20,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import load_iris -from sklearn.tree import DecisionTreeClassifier +from sklearn.tree import DecisionTreeClassifier, plot_tree # Parameters n_classes = 3 @@ -62,4 +64,8 @@ plt.suptitle("Decision surface of a decision tree using paired features") plt.legend(loc='lower right', borderpad=0, handletextpad=0) plt.axis("tight") + +plt.figure() +clf = DecisionTreeClassifier().fit(iris.data, iris.target) +plot_tree(clf, filled=True) plt.show() diff --git a/sklearn/tree/__init__.py b/sklearn/tree/__init__.py index 1394bd914d27c..b3abe30d019fa 100644 --- a/sklearn/tree/__init__.py +++ b/sklearn/tree/__init__.py @@ -7,7 +7,8 @@ from .tree import DecisionTreeRegressor from .tree import ExtraTreeClassifier from .tree import ExtraTreeRegressor -from .export import export_graphviz +from .export import export_graphviz, plot_tree __all__ = ["DecisionTreeClassifier", "DecisionTreeRegressor", - "ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz"] + "ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz", + "plot_tree"] diff --git a/sklearn/tree/_reingold_tilford.py b/sklearn/tree/_reingold_tilford.py new file mode 100644 index 0000000000000..d83969badb623 --- /dev/null +++ b/sklearn/tree/_reingold_tilford.py @@ -0,0 +1,203 @@ +# taken from 
https://github.com/llimllib/pymag-trees/blob/master/buchheim.py +# with slight modifications + +# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE +# Version 2, December 2004 +# +# Copyright (C) 2004 Sam Hocevar +# +# Everyone is permitted to copy and distribute verbatim or modified +# copies of this license document, and changing it is allowed as long +# as the name is changed. +# +# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE +# TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + +# 0. You just DO WHAT THE FUCK YOU WANT TO. + + +import numpy as np + + +class DrawTree(object): + def __init__(self, tree, parent=None, depth=0, number=1): + self.x = -1. + self.y = depth + self.tree = tree + self.children = [DrawTree(c, self, depth + 1, i + 1) + for i, c + in enumerate(tree.children)] + self.parent = parent + self.thread = None + self.mod = 0 + self.ancestor = self + self.change = self.shift = 0 + self._lmost_sibling = None + # this is the number of the node in its group of siblings 1..n + self.number = number + + def left(self): + return self.thread or len(self.children) and self.children[0] + + def right(self): + return self.thread or len(self.children) and self.children[-1] + + def lbrother(self): + n = None + if self.parent: + for node in self.parent.children: + if node == self: + return n + else: + n = node + return n + + def get_lmost_sibling(self): + if not self._lmost_sibling and self.parent and self != \ + self.parent.children[0]: + self._lmost_sibling = self.parent.children[0] + return self._lmost_sibling + lmost_sibling = property(get_lmost_sibling) + + def __str__(self): + return "%s: x=%s mod=%s" % (self.tree, self.x, self.mod) + + def __repr__(self): + return self.__str__() + + def max_extents(self): + extents = [c.max_extents() for c in self. 
children] + extents.append((self.x, self.y)) + return np.max(extents, axis=0) + + +def buchheim(tree): + dt = first_walk(DrawTree(tree)) + min = second_walk(dt) + if min < 0: + third_walk(dt, -min) + return dt + + +def third_walk(tree, n): + tree.x += n + for c in tree.children: + third_walk(c, n) + + +def first_walk(v, distance=1.): + if len(v.children) == 0: + if v.lmost_sibling: + v.x = v.lbrother().x + distance + else: + v.x = 0. + else: + default_ancestor = v.children[0] + for w in v.children: + first_walk(w) + default_ancestor = apportion(w, default_ancestor, distance) + # print("finished v =", v.tree, "children") + execute_shifts(v) + + midpoint = (v.children[0].x + v.children[-1].x) / 2 + + w = v.lbrother() + if w: + v.x = w.x + distance + v.mod = v.x - midpoint + else: + v.x = midpoint + return v + + +def apportion(v, default_ancestor, distance): + w = v.lbrother() + if w is not None: + # in buchheim notation: + # i == inner; o == outer; r == right; l == left; r = +; l = - + vir = vor = v + vil = w + vol = v.lmost_sibling + sir = sor = v.mod + sil = vil.mod + sol = vol.mod + while vil.right() and vir.left(): + vil = vil.right() + vir = vir.left() + vol = vol.left() + vor = vor.right() + vor.ancestor = v + shift = (vil.x + sil) - (vir.x + sir) + distance + if shift > 0: + move_subtree(ancestor(vil, v, default_ancestor), v, shift) + sir = sir + shift + sor = sor + shift + sil += vil.mod + sir += vir.mod + sol += vol.mod + sor += vor.mod + if vil.right() and not vor.right(): + vor.thread = vil.right() + vor.mod += sil - sor + else: + if vir.left() and not vol.left(): + vol.thread = vir.left() + vol.mod += sir - sol + default_ancestor = v + return default_ancestor + + +def move_subtree(wl, wr, shift): + subtrees = wr.number - wl.number + # print(wl.tree, "is conflicted with", wr.tree, 'moving', subtrees, + # 'shift', shift) + # print wl, wr, wr.number, wl.number, shift, subtrees, shift/subtrees + wr.change -= shift / subtrees + wr.shift += shift + wl.change += 
shift / subtrees + wr.x += shift + wr.mod += shift + + +def execute_shifts(v): + shift = change = 0 + for w in v.children[::-1]: + # print("shift:", w, shift, w.change) + w.x += shift + w.mod += shift + change += w.change + shift += w.shift + change + + +def ancestor(vil, v, default_ancestor): + # the relevant text is at the bottom of page 7 of + # "Improving Walker's Algorithm to Run in Linear Time" by Buchheim et al, + # (2002) + # http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.16.8757&rep=rep1&type=pdf + if vil.ancestor in v.parent.children: + return vil.ancestor + else: + return default_ancestor + + +def second_walk(v, m=0, depth=0, min=None): + v.x += m + v.y = depth + + if min is None or v.x < min: + min = v.x + + for w in v.children: + min = second_walk(w, m + v.mod, depth + 1, min) + + return min + + +class Tree(object): + def __init__(self, label="", node_id=-1, *children): + self.label = label + self.node_id = node_id + if children: + self.children = children + else: + self.children = [] diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index ef13790e65b42..fe127d77302b6 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -10,6 +10,7 @@ # Trevor Stephens # Li Li # License: BSD 3 clause +import warnings from numbers import Integral @@ -20,6 +21,7 @@ from . import _criterion from . import _tree +from ._reingold_tilford import buchheim, Tree def _color_brew(n): @@ -72,37 +74,30 @@ def __repr__(self): SENTINEL = Sentinel() -def export_graphviz(decision_tree, out_file=None, max_depth=None, - feature_names=None, class_names=None, label='all', - filled=False, leaves_parallel=False, impurity=True, - node_ids=False, proportion=False, rotate=False, - rounded=False, special_characters=False, precision=3): - """Export a decision tree in DOT format. - - This function generates a GraphViz representation of the decision tree, - which is then written into `out_file`. 
Once exported, graphical renderings - can be generated using, for example:: - - $ dot -Tps tree.dot -o tree.ps (PostScript format) - $ dot -Tpng tree.dot -o tree.png (PNG format) +def plot_tree(decision_tree, max_depth=None, feature_names=None, + class_names=None, label='all', filled=False, + impurity=True, node_ids=False, + proportion=False, rotate=False, rounded=False, + precision=3, ax=None, fontsize=None): + """Plot a decision tree. The sample counts that are shown are weighted with any sample_weights that might be present. + This function requires matplotlib, and works best with matplotlib >= 1.5. + + The visualization is fit automatically to the size of the axis. + Use the ``figsize`` or ``dpi`` arguments of ``plt.figure`` to control + the size of the rendering. Read more in the :ref:`User Guide <tree>`. + .. versionadded:: 0.21 + Parameters ---------- decision_tree : decision tree regressor or classifier The decision tree to be exported to GraphViz. - out_file : file object or string, optional (default=None) - Handle or name of the output file. If ``None``, the result is - returned as a string. - - .. versionchanged:: 0.20 - Default of out_file changed from "tree.dot" to None. - max_depth : int, optional (default=None) The maximum depth of the representation. If None, the tree is fully generated. @@ -125,9 +120,6 @@ def export_graphviz(decision_tree, out_file=None, max_depth=None, classification, extremity of values for regression, or purity of node for multi-output. - leaves_parallel : bool, optional (default=False) - When set to ``True``, draw all leaf nodes at the bottom of the tree. - impurity : bool, optional (default=True) When set to ``True``, show the impurity at each node. @@ -145,64 +137,113 @@ def export_graphviz(decision_tree, out_file=None, max_depth=None, When set to ``True``, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.
- special_characters : bool, optional (default=False) - When set to ``False``, ignore special characters for PostScript - compatibility. - precision : int, optional (default=3) Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node. + ax : matplotlib axis, optional (default=None) + Axes to plot to. If None, use current axis. Any previous content + is cleared. + + fontsize : int, optional (default=None) + Size of text font. If None, determined automatically to fit figure. + Returns ------- - dot_data : string - String representation of the input tree in GraphViz dot format. - Only returned if ``out_file`` is None. - - .. versionadded:: 0.18 + annotations : list of artists + List containing the artists for the annotation boxes making up the + tree. Examples -------- >>> from sklearn.datasets import load_iris >>> from sklearn import tree - >>> clf = tree.DecisionTreeClassifier() + >>> clf = tree.DecisionTreeClassifier(random_state=0) >>> iris = load_iris() >>> clf = clf.fit(iris.data, iris.target) - >>> tree.export_graphviz(clf, - ... out_file='tree.dot') # doctest: +SKIP + >>> tree.plot_tree(clf) # doctest: +SKIP + [Text(251.5,345.217,'X[3] <= 0.8... 
""" - - def get_color(value): + exporter = _MPLTreeExporter( + max_depth=max_depth, feature_names=feature_names, + class_names=class_names, label=label, filled=filled, + impurity=impurity, node_ids=node_ids, + proportion=proportion, rotate=rotate, rounded=rounded, + precision=precision, fontsize=fontsize) + return exporter.export(decision_tree, ax=ax) + + +class _BaseTreeExporter(object): + def __init__(self, max_depth=None, feature_names=None, + class_names=None, label='all', filled=False, + impurity=True, node_ids=False, + proportion=False, rotate=False, rounded=False, + precision=3, fontsize=None): + self.max_depth = max_depth + self.feature_names = feature_names + self.class_names = class_names + self.label = label + self.filled = filled + self.impurity = impurity + self.node_ids = node_ids + self.proportion = proportion + self.rotate = rotate + self.rounded = rounded + self.precision = precision + self.fontsize = fontsize + + def get_color(self, value): # Find the appropriate color & intensity for a node - if colors['bounds'] is None: + if self.colors['bounds'] is None: # Classification tree - color = list(colors['rgb'][np.argmax(value)]) + color = list(self.colors['rgb'][np.argmax(value)]) sorted_values = sorted(value, reverse=True) if len(sorted_values) == 1: alpha = 0 else: - alpha = int(np.round(255 * (sorted_values[0] - - sorted_values[1]) / - (1 - sorted_values[1]), 0)) + alpha = ((sorted_values[0] - sorted_values[1]) + / (1 - sorted_values[1])) else: # Regression tree or multi-output - color = list(colors['rgb'][0]) - alpha = int(np.round(255 * ((value - colors['bounds'][0]) / - (colors['bounds'][1] - - colors['bounds'][0])), 0)) - - # Return html color code in #RRGGBBAA format - color.append(alpha) - hex_codes = [str(i) for i in range(10)] - hex_codes.extend(['a', 'b', 'c', 'd', 'e', 'f']) - color = [hex_codes[c // 16] + hex_codes[c % 16] for c in color] - - return '#' + ''.join(color) + color = list(self.colors['rgb'][0]) + alpha = ((value - 
self.colors['bounds'][0]) / + (self.colors['bounds'][1] - self.colors['bounds'][0])) + # unpack numpy scalars + alpha = float(alpha) + # compute the color as alpha against white + color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color] + # Return html color code in #RRGGBB format + return '#%2x%2x%2x' % tuple(color) + + def get_fill_color(self, tree, node_id): + # Fetch appropriate color for node + if 'rgb' not in self.colors: + # Initialize colors and bounds if required + self.colors['rgb'] = _color_brew(tree.n_classes[0]) + if tree.n_outputs != 1: + # Find max and min impurities for multi-output + self.colors['bounds'] = (np.min(-tree.impurity), + np.max(-tree.impurity)) + elif (tree.n_classes[0] == 1 and + len(np.unique(tree.value)) != 1): + # Find max and min values in leaf nodes for regression + self.colors['bounds'] = (np.min(tree.value), + np.max(tree.value)) + if tree.n_outputs == 1: + node_val = (tree.value[node_id][0, :] / + tree.weighted_n_node_samples[node_id]) + if tree.n_classes[0] == 1: + # Regression + node_val = tree.value[node_id][0, :] + else: + # If multi-output color node by impurity + node_val = -tree.impurity[node_id] + return self.get_color(node_val) - def node_to_str(tree, node_id, criterion): + def node_to_str(self, tree, node_id, criterion): # Generate the node content string if tree.n_outputs == 1: value = tree.value[node_id][0, :] @@ -210,18 +251,13 @@ def node_to_str(tree, node_id, criterion): value = tree.value[node_id] # Should labels be shown? - labels = (label == 'root' and node_id == 0) or label == 'all' + labels = (self.label == 'root' and node_id == 0) or self.label == 'all' - # PostScript compatibility for special characters - if special_characters: - characters = ['#', '', '', '≤', '
', '>'] - node_string = '<' - else: - characters = ['#', '[', ']', '<=', '\\n', '"'] - node_string = '"' + characters = self.characters + node_string = characters[-1] # Write node ID - if node_ids: + if self.node_ids: if labels: node_string += 'node ' node_string += characters[0] + str(node_id) + characters[4] @@ -229,8 +265,8 @@ def node_to_str(tree, node_id, criterion): # Write decision criteria if tree.children_left[node_id] != _tree.TREE_LEAF: # Always write node decision criteria, except for leaves - if feature_names is not None: - feature = feature_names[tree.feature[node_id]] + if self.feature_names is not None: + feature = self.feature_names[tree.feature[node_id]] else: feature = "X%s%s%s" % (characters[1], tree.feature[node_id], @@ -238,24 +274,24 @@ def node_to_str(tree, node_id, criterion): node_string += '%s %s %s%s' % (feature, characters[3], round(tree.threshold[node_id], - precision), + self.precision), characters[4]) # Write impurity - if impurity: + if self.impurity: if isinstance(criterion, _criterion.FriedmanMSE): criterion = "friedman_mse" elif not isinstance(criterion, six.string_types): criterion = "impurity" if labels: node_string += '%s = ' % criterion - node_string += (str(round(tree.impurity[node_id], precision)) + - characters[4]) + node_string += (str(round(tree.impurity[node_id], self.precision)) + + characters[4]) # Write node sample count if labels: node_string += 'samples = ' - if proportion: + if self.proportion: percent = (100. 
* tree.n_node_samples[node_id] / float(tree.n_node_samples[0])) node_string += (str(round(percent, 1)) + '%' + @@ -265,23 +301,23 @@ def node_to_str(tree, node_id, criterion): characters[4]) # Write node class distribution / regression value - if proportion and tree.n_classes[0] != 1: + if self.proportion and tree.n_classes[0] != 1: # For classification this will show the proportion of samples value = value / tree.weighted_n_node_samples[node_id] if labels: node_string += 'value = ' if tree.n_classes[0] == 1: # Regression - value_text = np.around(value, precision) - elif proportion: + value_text = np.around(value, self.precision) + elif self.proportion: # Classification - value_text = np.around(value, precision) + value_text = np.around(value, self.precision) elif np.all(np.equal(np.mod(value, 1), 0)): # Classification without floating-point weights value_text = value.astype(int) else: # Classification with floating-point weights - value_text = np.around(value, precision) + value_text = np.around(value, self.precision) # Strip whitespace value_text = str(value_text.astype('S32')).replace("b'", "'") value_text = value_text.replace("' '", ", ").replace("'", "") @@ -291,14 +327,14 @@ def node_to_str(tree, node_id, criterion): node_string += value_text + characters[4] # Write node majority class - if (class_names is not None and + if (self.class_names is not None and tree.n_classes[0] != 1 and tree.n_outputs == 1): # Only done for single-output classification trees if labels: node_string += 'class = ' - if class_names is not True: - class_name = class_names[np.argmax(value)] + if self.class_names is not True: + class_name = self.class_names[np.argmax(value)] else: class_name = "y%s%s%s" % (characters[1], np.argmax(value), @@ -306,14 +342,109 @@ def node_to_str(tree, node_id, criterion): node_string += class_name # Clean up any trailing newlines - if node_string[-2:] == '\\n': - node_string = node_string[:-2] - if node_string[-5:] == '
': - node_string = node_string[:-5] + if node_string.endswith(characters[4]): + node_string = node_string[:-len(characters[4])] return node_string + characters[5] - def recurse(tree, node_id, criterion, parent=None, depth=0): + +class _DOTTreeExporter(_BaseTreeExporter): + def __init__(self, out_file=SENTINEL, max_depth=None, + feature_names=None, class_names=None, label='all', + filled=False, leaves_parallel=False, impurity=True, + node_ids=False, proportion=False, rotate=False, rounded=False, + special_characters=False, precision=3): + + super(_DOTTreeExporter, self).__init__( + max_depth=max_depth, feature_names=feature_names, + class_names=class_names, label=label, filled=filled, + impurity=impurity, + node_ids=node_ids, proportion=proportion, rotate=rotate, + rounded=rounded, + precision=precision) + self.leaves_parallel = leaves_parallel + self.out_file = out_file + self.special_characters = special_characters + + # PostScript compatibility for special characters + if special_characters: + self.characters = ['#', '', '', '≤', '
', + '>', '<'] + else: + self.characters = ['#', '[', ']', '<=', '\\n', '"', '"'] + + # validate + if isinstance(precision, Integral): + if precision < 0: + raise ValueError("'precision' should be greater or equal to 0." + " Got {} instead.".format(precision)) + else: + raise ValueError("'precision' should be an integer. Got {}" + " instead.".format(type(precision))) + + # The depth of each node for plotting with 'leaf' option + self.ranks = {'leaves': []} + # The colors to render each node with + self.colors = {'bounds': None} + + def export(self, decision_tree): + # Check length of feature_names before getting into the tree node + # Raise error if length of feature_names does not match + # n_features_ in the decision_tree + if self.feature_names is not None: + if len(self.feature_names) != decision_tree.n_features_: + raise ValueError("Length of feature_names, %d " + "does not match number of features, %d" + % (len(self.feature_names), + decision_tree.n_features_)) + # each part writes to out_file + self.head() + # Now recurse the tree and add node & edge attributes + if isinstance(decision_tree, _tree.Tree): + self.recurse(decision_tree, 0, criterion="impurity") + else: + self.recurse(decision_tree.tree_, 0, + criterion=decision_tree.criterion) + + self.tail() + + def tail(self): + # If required, draw leaf nodes at same depth as each other + if self.leaves_parallel: + for rank in sorted(self.ranks): + self.out_file.write( + "{rank=same ; " + + "; ".join(r for r in self.ranks[rank]) + "} ;\n") + self.out_file.write("}") + + def head(self): + self.out_file.write('digraph Tree {\n') + + # Specify node aesthetics + self.out_file.write('node [shape=box') + rounded_filled = [] + if self.filled: + rounded_filled.append('filled') + if self.rounded: + rounded_filled.append('rounded') + if len(rounded_filled) > 0: + self.out_file.write( + ', style="%s", color="black"' + % ", ".join(rounded_filled)) + if self.rounded: + self.out_file.write(', fontname=helvetica') + 
self.out_file.write('] ;\n') + + # Specify graph & edge aesthetics + if self.leaves_parallel: + self.out_file.write( + 'graph [ranksep=equally, splines=polyline] ;\n') + if self.rounded: + self.out_file.write('edge [fontname=helvetica] ;\n') + if self.rotate: + self.out_file.write('rankdir=LR ;\n') + + def recurse(self, tree, node_id, criterion, parent=None, depth=0): if node_id == _tree.TREE_LEAF: raise ValueError("Invalid node_id %s" % _tree.TREE_LEAF) @@ -321,93 +452,75 @@ def recurse(tree, node_id, criterion, parent=None, depth=0): right_child = tree.children_right[node_id] # Add node with description - if max_depth is None or depth <= max_depth: + if self.max_depth is None or depth <= self.max_depth: # Collect ranks for 'leaf' option in plot_options if left_child == _tree.TREE_LEAF: - ranks['leaves'].append(str(node_id)) - elif str(depth) not in ranks: - ranks[str(depth)] = [str(node_id)] + self.ranks['leaves'].append(str(node_id)) + elif str(depth) not in self.ranks: + self.ranks[str(depth)] = [str(node_id)] else: - ranks[str(depth)].append(str(node_id)) - - out_file.write('%d [label=%s' - % (node_id, - node_to_str(tree, node_id, criterion))) - - if filled: - # Fetch appropriate color for node - if 'rgb' not in colors: - # Initialize colors and bounds if required - colors['rgb'] = _color_brew(tree.n_classes[0]) - if tree.n_outputs != 1: - # Find max and min impurities for multi-output - colors['bounds'] = (np.min(-tree.impurity), - np.max(-tree.impurity)) - elif (tree.n_classes[0] == 1 and - len(np.unique(tree.value)) != 1): - # Find max and min values in leaf nodes for regression - colors['bounds'] = (np.min(tree.value), - np.max(tree.value)) - if tree.n_outputs == 1: - node_val = (tree.value[node_id][0, :] / - tree.weighted_n_node_samples[node_id]) - if tree.n_classes[0] == 1: - # Regression - node_val = tree.value[node_id][0, :] - else: - # If multi-output color node by impurity - node_val = -tree.impurity[node_id] - out_file.write(', fillcolor="%s"' % 
get_color(node_val)) - out_file.write('] ;\n') + self.ranks[str(depth)].append(str(node_id)) + + self.out_file.write( + '%d [label=%s' % (node_id, self.node_to_str(tree, node_id, + criterion))) + + if self.filled: + self.out_file.write(', fillcolor="%s"' + % self.get_fill_color(tree, node_id)) + self.out_file.write('] ;\n') if parent is not None: # Add edge to parent - out_file.write('%d -> %d' % (parent, node_id)) + self.out_file.write('%d -> %d' % (parent, node_id)) if parent == 0: # Draw True/False labels if parent is root node - angles = np.array([45, -45]) * ((rotate - .5) * -2) - out_file.write(' [labeldistance=2.5, labelangle=') + angles = np.array([45, -45]) * ((self.rotate - .5) * -2) + self.out_file.write(' [labeldistance=2.5, labelangle=') if node_id == 1: - out_file.write('%d, headlabel="True"]' % angles[0]) + self.out_file.write('%d, headlabel="True"]' % + angles[0]) else: - out_file.write('%d, headlabel="False"]' % angles[1]) - out_file.write(' ;\n') + self.out_file.write('%d, headlabel="False"]' % + angles[1]) + self.out_file.write(' ;\n') if left_child != _tree.TREE_LEAF: - recurse(tree, left_child, criterion=criterion, parent=node_id, - depth=depth + 1) - recurse(tree, right_child, criterion=criterion, parent=node_id, - depth=depth + 1) + self.recurse(tree, left_child, criterion=criterion, + parent=node_id, depth=depth + 1) + self.recurse(tree, right_child, criterion=criterion, + parent=node_id, depth=depth + 1) else: - ranks['leaves'].append(str(node_id)) + self.ranks['leaves'].append(str(node_id)) - out_file.write('%d [label="(...)"' % node_id) - if filled: + self.out_file.write('%d [label="(...)"' % node_id) + if self.filled: # color cropped nodes grey - out_file.write(', fillcolor="#C0C0C0"') - out_file.write('] ;\n' % node_id) + self.out_file.write(', fillcolor="#C0C0C0"') + self.out_file.write('] ;\n' % node_id) if parent is not None: # Add edge to parent - out_file.write('%d -> %d ;\n' % (parent, node_id)) + self.out_file.write('%d -> %d 
;\n' % (parent, node_id)) - check_is_fitted(decision_tree, 'tree_') - own_file = False - return_string = False - try: - if isinstance(out_file, six.string_types): - if six.PY3: - out_file = open(out_file, "w", encoding="utf-8") - else: - out_file = open(out_file, "wb") - own_file = True - if out_file is None: - return_string = True - out_file = six.StringIO() +class _MPLTreeExporter(_BaseTreeExporter): + def __init__(self, max_depth=None, feature_names=None, + class_names=None, label='all', filled=False, + impurity=True, node_ids=False, + proportion=False, rotate=False, rounded=False, + precision=3, fontsize=None): + + super(_MPLTreeExporter, self).__init__( + max_depth=max_depth, feature_names=feature_names, + class_names=class_names, label=label, filled=filled, + impurity=impurity, node_ids=node_ids, proportion=proportion, + rotate=rotate, rounded=rounded, precision=precision) + self.fontsize = fontsize + # validate if isinstance(precision, Integral): if precision < 0: raise ValueError("'precision' should be greater or equal to 0." @@ -416,57 +529,254 @@ def recurse(tree, node_id, criterion, parent=None, depth=0): raise ValueError("'precision' should be an integer. 
Got {}" " instead.".format(type(precision))) - # Check length of feature_names before getting into the tree node - # Raise error if length of feature_names does not match - # n_features_ in the decision_tree - if feature_names is not None: - if len(feature_names) != decision_tree.n_features_: - raise ValueError("Length of feature_names, %d " - "does not match number of features, %d" - % (len(feature_names), - decision_tree.n_features_)) - # The depth of each node for plotting with 'leaf' option - ranks = {'leaves': []} + self.ranks = {'leaves': []} # The colors to render each node with - colors = {'bounds': None} + self.colors = {'bounds': None} - out_file.write('digraph Tree {\n') + self.characters = ['#', '[', ']', '<=', '\n', '', ''] - # Specify node aesthetics - out_file.write('node [shape=box') - rounded_filled = [] - if filled: - rounded_filled.append('filled') - if rounded: - rounded_filled.append('rounded') - if len(rounded_filled) > 0: - out_file.write(', style="%s", color="black"' - % ", ".join(rounded_filled)) - if rounded: - out_file.write(', fontname=helvetica') - out_file.write('] ;\n') + self.bbox_args = dict(fc='w') + if self.rounded: + self.bbox_args['boxstyle'] = "round" + else: + # matplotlib <1.5 requires explicit boxstyle + self.bbox_args['boxstyle'] = "square" + + self.arrow_args = dict(arrowstyle="<-") + + def _make_tree(self, node_id, et, depth=0): + # traverses _tree.Tree recursively, builds intermediate + # "_reingold_tilford.Tree" object + name = self.node_to_str(et, node_id, criterion='entropy') + if (et.children_left[node_id] != _tree.TREE_LEAF + and (self.max_depth is None or depth <= self.max_depth)): + children = [self._make_tree(et.children_left[node_id], et, + depth=depth + 1), + self._make_tree(et.children_right[node_id], et, + depth=depth + 1)] + else: + return Tree(name, node_id) + return Tree(name, node_id, *children) + + def export(self, decision_tree, ax=None): + import matplotlib.pyplot as plt + from matplotlib.text import 
Annotation + if ax is None: + ax = plt.gca() + ax.clear() + ax.set_axis_off() + my_tree = self._make_tree(0, decision_tree.tree_) + draw_tree = buchheim(my_tree) + + # important to make sure we're still + # inside the axis after drawing the box + # this makes sense because the width of a box + # is about the same as the distance between boxes + max_x, max_y = draw_tree.max_extents() + 1 + ax_width = ax.get_window_extent().width + ax_height = ax.get_window_extent().height + + scale_x = ax_width / max_x + scale_y = ax_height / max_y + + self.recurse(draw_tree, decision_tree.tree_, ax, + scale_x, scale_y, ax_height) + + anns = [ann for ann in ax.get_children() + if isinstance(ann, Annotation)] + + # update sizes of all bboxes + renderer = ax.figure.canvas.get_renderer() + + for ann in anns: + ann.update_bbox_position_size(renderer) + + if self.fontsize is None: + # get figure to data transform + # adjust fontsize to avoid overlap + # get max box width and height + try: + extents = [ann.get_bbox_patch().get_window_extent() + for ann in anns] + max_width = max([extent.width for extent in extents]) + max_height = max([extent.height for extent in extents]) + # width should be around scale_x in axis coordinates + size = anns[0].get_fontsize() * min(scale_x / max_width, + scale_y / max_height) + for ann in anns: + ann.set_fontsize(size) + except AttributeError: + # matplotlib < 1.5 + warnings.warn("Automatic scaling of tree plots requires " + "matplotlib 1.5 or higher. 
Please specify " + "fontsize.") + + return anns + + def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0): + # need to copy bbox args because matplotlib <1.5 modifies them + kwargs = dict(bbox=self.bbox_args.copy(), ha='center', va='center', + zorder=100 - 10 * depth, xycoords='axes pixels') + + if self.fontsize is not None: + kwargs['fontsize'] = self.fontsize + + # offset things by .5 to center them in plot + xy = ((node.x + .5) * scale_x, height - (node.y + .5) * scale_y) + + if self.max_depth is None or depth <= self.max_depth: + if self.filled: + kwargs['bbox']['fc'] = self.get_fill_color(tree, + node.tree.node_id) + if node.parent is None: + # root + ax.annotate(node.tree.label, xy, **kwargs) + else: + xy_parent = ((node.parent.x + .5) * scale_x, + height - (node.parent.y + .5) * scale_y) + kwargs["arrowprops"] = self.arrow_args + ax.annotate(node.tree.label, xy_parent, xy, **kwargs) + for child in node.children: + self.recurse(child, tree, ax, scale_x, scale_y, height, + depth=depth + 1) - # Specify graph & edge aesthetics - if leaves_parallel: - out_file.write('graph [ranksep=equally, splines=polyline] ;\n') - if rounded: - out_file.write('edge [fontname=helvetica] ;\n') - if rotate: - out_file.write('rankdir=LR ;\n') + else: + xy_parent = ((node.parent.x + .5) * scale_x, + height - (node.parent.y + .5) * scale_y) + kwargs["arrowprops"] = self.arrow_args + kwargs['bbox']['fc'] = 'grey' + ax.annotate("\n (...)
\n", xy_parent, xy, **kwargs) - # Now recurse the tree and add node & edge attributes - recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion) - # If required, draw leaf nodes at same depth as each other - if leaves_parallel: - for rank in sorted(ranks): - out_file.write("{rank=same ; " + - "; ".join(r for r in ranks[rank]) + "} ;\n") - out_file.write("}") +def export_graphviz(decision_tree, out_file=None, max_depth=None, + feature_names=None, class_names=None, label='all', + filled=False, leaves_parallel=False, impurity=True, + node_ids=False, proportion=False, rotate=False, + rounded=False, special_characters=False, precision=3): + """Export a decision tree in DOT format. + + This function generates a GraphViz representation of the decision tree, + which is then written into `out_file`. Once exported, graphical renderings + can be generated using, for example:: + + $ dot -Tps tree.dot -o tree.ps (PostScript format) + $ dot -Tpng tree.dot -o tree.png (PNG format) + + The sample counts that are shown are weighted with any sample_weights that + might be present. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + decision_tree : decision tree classifier + The decision tree to be exported to GraphViz. + + out_file : file object or string, optional (default=None) + Handle or name of the output file. If ``None``, the result is + returned as a string. + + .. versionchanged:: 0.20 + Default of out_file changed from "tree.dot" to None. + + max_depth : int, optional (default=None) + The maximum depth of the representation. If None, the tree is fully + generated. + + feature_names : list of strings, optional (default=None) + Names of each of the features. + + class_names : list of strings, bool or None, optional (default=None) + Names of each of the target classes in ascending numerical order. + Only relevant for classification and not supported for multi-output. + If ``True``, shows a symbolic representation of the class name. 
+ + label : {'all', 'root', 'none'}, optional (default='all') + Whether to show informative labels for impurity, etc. + Options include 'all' to show at every node, 'root' to show only at + the top root node, or 'none' to not show at any node. + + filled : bool, optional (default=False) + When set to ``True``, paint nodes to indicate majority class for + classification, extremity of values for regression, or purity of node + for multi-output. + + leaves_parallel : bool, optional (default=False) + When set to ``True``, draw all leaf nodes at the bottom of the tree. + + impurity : bool, optional (default=True) + When set to ``True``, show the impurity at each node. + + node_ids : bool, optional (default=False) + When set to ``True``, show the ID number on each node. + + proportion : bool, optional (default=False) + When set to ``True``, change the display of 'values' and/or 'samples' + to be proportions and percentages respectively. + + rotate : bool, optional (default=False) + When set to ``True``, orient tree left to right rather than top-down. + + rounded : bool, optional (default=False) + When set to ``True``, draw node boxes with rounded corners and use + Helvetica fonts instead of Times-Roman. + + special_characters : bool, optional (default=False) + When set to ``False``, ignore special characters for PostScript + compatibility. + + precision : int, optional (default=3) + Number of digits of precision for floating point in the values of + impurity, threshold and value attributes of each node. + + Returns + ------- + dot_data : string + String representation of the input tree in GraphViz dot format. + Only returned if ``out_file`` is None. + + .. versionadded:: 0.18 + + Examples + -------- + >>> from sklearn.datasets import load_iris + >>> from sklearn import tree + + >>> clf = tree.DecisionTreeClassifier() + >>> iris = load_iris() + + >>> clf = clf.fit(iris.data, iris.target) + >>> tree.export_graphviz(clf) # doctest: +ELLIPSIS + 'digraph Tree {... 
+ """ + + check_is_fitted(decision_tree, 'tree_') + own_file = False + return_string = False + try: + if isinstance(out_file, six.string_types): + if six.PY3: + out_file = open(out_file, "w", encoding="utf-8") + else: + out_file = open(out_file, "wb") + own_file = True + + if out_file is None: + return_string = True + out_file = six.StringIO() + + exporter = _DOTTreeExporter( + out_file=out_file, max_depth=max_depth, + feature_names=feature_names, class_names=class_names, label=label, + filled=filled, leaves_parallel=leaves_parallel, impurity=impurity, + node_ids=node_ids, proportion=proportion, rotate=rotate, + rounded=rounded, special_characters=special_characters, + precision=precision) + exporter.export(decision_tree) if return_string: - return out_file.getvalue() + return exporter.out_file.getvalue() finally: if own_file: diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index c43e0a4f32392..2471914fa44ce 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -1,6 +1,7 @@ """ Testing for export functions of decision trees (sklearn.tree.export). """ +import pytest from re import finditer, search @@ -9,7 +10,7 @@ from sklearn.base import is_classifier from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.ensemble import GradientBoostingClassifier -from sklearn.tree import export_graphviz +from sklearn.tree import export_graphviz, plot_tree from sklearn.externals.six import StringIO from sklearn.utils.testing import (assert_in, assert_equal, assert_raises, assert_less_equal, assert_raises_regex, @@ -92,13 +93,13 @@ def test_graphviz_toy(): 'fontname=helvetica] ;\n' \ 'edge [fontname=helvetica] ;\n' \ '0 [label=0 ≤ 0.0
samples = 100.0%
' \ - 'value = [0.5, 0.5]>, fillcolor="#e5813900"] ;\n' \ + 'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n' \ '1 [label=value = [1.0, 0.0]>, ' \ - 'fillcolor="#e58139ff"] ;\n' \ + 'fillcolor="#e58139"] ;\n' \ '0 -> 1 [labeldistance=2.5, labelangle=45, ' \ 'headlabel="True"] ;\n' \ '2 [label=value = [0.0, 1.0]>, ' \ - 'fillcolor="#399de5ff"] ;\n' \ + 'fillcolor="#399de5"] ;\n' \ '0 -> 2 [labeldistance=2.5, labelangle=-45, ' \ 'headlabel="False"] ;\n' \ '}' @@ -126,7 +127,7 @@ def test_graphviz_toy(): contents2 = 'digraph Tree {\n' \ 'node [shape=box, style="filled", color="black"] ;\n' \ '0 [label="node #0\\nX[0] <= 0.0\\ngini = 0.5\\n' \ - 'samples = 6\\nvalue = [3, 3]", fillcolor="#e5813900"] ;\n' \ + 'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n' \ '1 [label="(...)", fillcolor="#C0C0C0"] ;\n' \ '0 -> 1 ;\n' \ '2 [label="(...)", fillcolor="#C0C0C0"] ;\n' \ @@ -148,21 +149,21 @@ def test_graphviz_toy(): 'node [shape=box, style="filled", color="black"] ;\n' \ '0 [label="X[0] <= 0.0\\nsamples = 6\\n' \ 'value = [[3.0, 1.5, 0.0]\\n' \ - '[3.0, 1.0, 0.5]]", fillcolor="#e5813900"] ;\n' \ + '[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n' \ '1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n' \ - '[3, 0, 0]]", fillcolor="#e58139ff"] ;\n' \ + '[3, 0, 0]]", fillcolor="#e58139"] ;\n' \ '0 -> 1 [labeldistance=2.5, labelangle=45, ' \ 'headlabel="True"] ;\n' \ '2 [label="X[0] <= 1.5\\nsamples = 3\\n' \ 'value = [[0.0, 1.5, 0.0]\\n' \ - '[0.0, 1.0, 0.5]]", fillcolor="#e5813986"] ;\n' \ + '[0.0, 1.0, 0.5]]", fillcolor="#f1bd97"] ;\n' \ '0 -> 2 [labeldistance=2.5, labelangle=-45, ' \ 'headlabel="False"] ;\n' \ '3 [label="samples = 2\\nvalue = [[0, 1, 0]\\n' \ - '[0, 1, 0]]", fillcolor="#e58139ff"] ;\n' \ + '[0, 1, 0]]", fillcolor="#e58139"] ;\n' \ '2 -> 3 ;\n' \ '4 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n' \ - '[0.0, 0.0, 0.5]]", fillcolor="#e58139ff"] ;\n' \ + '[0.0, 0.0, 0.5]]", fillcolor="#e58139"] ;\n' \ '2 -> 4 ;\n' \ '}' @@ -184,13 +185,13 @@ def 
test_graphviz_toy(): 'edge [fontname=helvetica] ;\n' \ 'rankdir=LR ;\n' \ '0 [label="X[0] <= 0.0\\nmse = 1.0\\nsamples = 6\\n' \ - 'value = 0.0", fillcolor="#e5813980"] ;\n' \ + 'value = 0.0", fillcolor="#f2c09c"] ;\n' \ '1 [label="mse = 0.0\\nsamples = 3\\nvalue = -1.0", ' \ - 'fillcolor="#e5813900"] ;\n' \ + 'fillcolor="#ffffff"] ;\n' \ '0 -> 1 [labeldistance=2.5, labelangle=-45, ' \ 'headlabel="True"] ;\n' \ '2 [label="mse = 0.0\\nsamples = 3\\nvalue = 1.0", ' \ - 'fillcolor="#e58139ff"] ;\n' \ + 'fillcolor="#e58139"] ;\n' \ '0 -> 2 [labeldistance=2.5, labelangle=45, ' \ 'headlabel="False"] ;\n' \ '{rank=same ; 0} ;\n' \ @@ -207,7 +208,7 @@ def test_graphviz_toy(): contents2 = 'digraph Tree {\n' \ 'node [shape=box, style="filled", color="black"] ;\n' \ '0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", ' \ - 'fillcolor="#e5813900"] ;\n' \ + 'fillcolor="#ffffff"] ;\n' \ '}' @@ -308,3 +309,23 @@ def test_precision(): for finding in finditer(r"<= \d+\.\d+", dot_data): assert_equal(len(search(r"\.\d+", finding.group()).group()), precision + 1) + + +def test_plot_tree(): + # mostly smoke tests + pytest.importorskip("matplotlib.pyplot") + # Check correctness of export_graphviz + clf = DecisionTreeClassifier(max_depth=3, + min_samples_split=2, + criterion="gini", + random_state=2) + clf.fit(X, y) + + # Test export code + feature_names = ['first feat', 'sepal_width'] + nodes = plot_tree(clf, feature_names=feature_names) + assert len(nodes) == 3 + assert nodes[0].get_text() == ("first feat <= 0.0\nentropy = 0.5\n" + "samples = 6\nvalue = [3, 3]") + assert nodes[1].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [3, 0]" + assert nodes[2].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [0, 3]" diff --git a/sklearn/tree/tests/test_reingold_tilford.py b/sklearn/tree/tests/test_reingold_tilford.py new file mode 100644 index 0000000000000..4cb27ce6effb9 --- /dev/null +++ b/sklearn/tree/tests/test_reingold_tilford.py @@ -0,0 +1,54 @@ +import numpy as np +import pytest 
+from sklearn.tree._reingold_tilford import buchheim, Tree
+
+simple_tree = Tree("", 0,
+                   Tree("", 1),
+                   Tree("", 2))
+
+bigger_tree = Tree("", 0,
+                   Tree("", 1,
+                        Tree("", 3),
+                        Tree("", 4,
+                             Tree("", 7),
+                             Tree("", 8)
+                             ),
+                        ),
+                   Tree("", 2,
+                        Tree("", 5),
+                        Tree("", 6)
+                        )
+                   )
+
+
+@pytest.mark.parametrize("tree, n_nodes", [(simple_tree, 3), (bigger_tree, 9)])
+def test_buchheim(tree, n_nodes):
+    """Check buchheim() layout: level spacing, centring, no x clashes."""
+    def walk_tree(draw_tree):
+        # returns the (x, y) of every node, asserting geometry on the way
+        res = [(draw_tree.x, draw_tree.y)]
+        for child in draw_tree.children:
+            # parents higher than children:
+            assert child.y == draw_tree.y + 1
+            res.extend(walk_tree(child))
+        if len(draw_tree.children):
+            # these trees are always binary
+            # parents are centered above children
+            assert draw_tree.x == (draw_tree.children[0].x
+                                   + draw_tree.children[1].x) / 2
+        return res
+
+    layout = buchheim(tree)
+    coordinates = walk_tree(layout)
+    assert len(coordinates) == n_nodes
+    # x positions must be unique within each level, otherwise two nodes would
+    # share the same spot.  Each entry is that node's own (x, y) pair, so
+    # filter on its y to pick out one level at a time.
+    depth = 0
+    while True:
+        x_at_this_depth = [x for x, y in coordinates if y == depth]
+        if not x_at_this_depth:
+            # reached all leaves
+            break
+        assert len(np.unique(x_at_this_depth)) == len(x_at_this_depth)
+        depth += 1