diff --git a/doc/Makefile b/doc/Makefile
index fcb547d14e2b0..6629518fc556a 100644
--- a/doc/Makefile
+++ b/doc/Makefile
@@ -13,7 +13,7 @@ endif
# Internal variables.
PAPEROPT_a4 = -D latex_paper_size=a4
PAPEROPT_letter = -D latex_paper_size=letter
-ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS)\
+ALLSPHINXOPTS = -T -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS)\
$(EXAMPLES_PATTERN_OPTS) .
diff --git a/doc/modules/tree.rst b/doc/modules/tree.rst
index 4c3f584b079ab..fe5bed4c0221f 100644
--- a/doc/modules/tree.rst
+++ b/doc/modules/tree.rst
@@ -124,9 +124,20 @@ Using the Iris dataset, we can construct a tree as follows::
>>> clf = tree.DecisionTreeClassifier()
>>> clf = clf.fit(iris.data, iris.target)
-Once trained, we can export the tree in `Graphviz
+Once trained, you can plot the tree with the plot_tree function::
+
+
+ >>> tree.plot_tree(clf.fit(iris.data, iris.target)) # doctest: +SKIP
+
+.. figure:: ../auto_examples/tree/images/sphx_glr_plot_iris_002.png
+ :target: ../auto_examples/tree/plot_iris.html
+ :scale: 75
+ :align: center
+
+We can also export the tree in `Graphviz
<https://www.graphviz.org/>`_ format using the :func:`export_graphviz`
-exporter. If you use the `conda `_ package manager, the graphviz binaries
+exporter. If you use the `conda <https://conda.io>`_ package manager, the graphviz binaries
+
and the python package can be installed with
conda install python-graphviz
diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst
index be200250a9ae5..c2312875b1c68 100644
--- a/doc/whats_new/v0.21.rst
+++ b/doc/whats_new/v0.21.rst
@@ -62,6 +62,11 @@ Support for Python 3.4 and below has been officially dropped.
:mod:`sklearn.tree`
...................
+- |Feature| Decision Trees can now be plotted with matplotlib using
+ :func:`tree.plot_tree` without relying on the ``dot`` library,
+ removing a hard-to-install dependency.
+ :issue:`8508` by `Andreas Müller`_.
+
- |Feature| ``get_n_leaves()`` and ``get_depth()`` have been added to
:class:`tree.BaseDecisionTree` and consequently all estimators based
on it, including :class:`tree.DecisionTreeClassifier`,
diff --git a/examples/tree/plot_iris.py b/examples/tree/plot_iris.py
index f299aab18d7d1..60328c4f90d4f 100644
--- a/examples/tree/plot_iris.py
+++ b/examples/tree/plot_iris.py
@@ -11,6 +11,8 @@
For each pair of iris features, the decision tree learns decision
boundaries made of combinations of simple thresholding rules inferred from
the training samples.
+
+We also show the tree structure of a model built on all of the features.
"""
print(__doc__)
@@ -18,7 +20,7 @@
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
-from sklearn.tree import DecisionTreeClassifier
+from sklearn.tree import DecisionTreeClassifier, plot_tree
# Parameters
n_classes = 3
@@ -62,4 +64,8 @@
plt.suptitle("Decision surface of a decision tree using paired features")
plt.legend(loc='lower right', borderpad=0, handletextpad=0)
plt.axis("tight")
+
+plt.figure()
+clf = DecisionTreeClassifier().fit(iris.data, iris.target)
+plot_tree(clf, filled=True)
plt.show()
diff --git a/sklearn/tree/__init__.py b/sklearn/tree/__init__.py
index 1394bd914d27c..b3abe30d019fa 100644
--- a/sklearn/tree/__init__.py
+++ b/sklearn/tree/__init__.py
@@ -7,7 +7,8 @@
from .tree import DecisionTreeRegressor
from .tree import ExtraTreeClassifier
from .tree import ExtraTreeRegressor
-from .export import export_graphviz
+from .export import export_graphviz, plot_tree
__all__ = ["DecisionTreeClassifier", "DecisionTreeRegressor",
- "ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz"]
+ "ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz",
+ "plot_tree"]
diff --git a/sklearn/tree/_reingold_tilford.py b/sklearn/tree/_reingold_tilford.py
new file mode 100644
index 0000000000000..d83969badb623
--- /dev/null
+++ b/sklearn/tree/_reingold_tilford.py
@@ -0,0 +1,203 @@
+# taken from https://github.com/llimllib/pymag-trees/blob/master/buchheim.py
+# with slight modifications
+
+# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
+# Version 2, December 2004
+#
+# Copyright (C) 2004 Sam Hocevar
+#
+# Everyone is permitted to copy and distribute verbatim or modified
+# copies of this license document, and changing it is allowed as long
+# as the name is changed.
+#
+# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
+# TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
+
+# 0. You just DO WHAT THE FUCK YOU WANT TO.
+
+
+import numpy as np
+
+
+class DrawTree(object):
+ def __init__(self, tree, parent=None, depth=0, number=1):
+ self.x = -1.
+ self.y = depth
+ self.tree = tree
+ self.children = [DrawTree(c, self, depth + 1, i + 1)
+ for i, c
+ in enumerate(tree.children)]
+ self.parent = parent
+ self.thread = None
+ self.mod = 0
+ self.ancestor = self
+ self.change = self.shift = 0
+ self._lmost_sibling = None
+ # this is the number of the node in its group of siblings 1..n
+ self.number = number
+
+ def left(self):
+ return self.thread or len(self.children) and self.children[0]
+
+ def right(self):
+ return self.thread or len(self.children) and self.children[-1]
+
+ def lbrother(self):
+ n = None
+ if self.parent:
+ for node in self.parent.children:
+ if node == self:
+ return n
+ else:
+ n = node
+ return n
+
+ def get_lmost_sibling(self):
+ if not self._lmost_sibling and self.parent and self != \
+ self.parent.children[0]:
+ self._lmost_sibling = self.parent.children[0]
+ return self._lmost_sibling
+ lmost_sibling = property(get_lmost_sibling)
+
+ def __str__(self):
+ return "%s: x=%s mod=%s" % (self.tree, self.x, self.mod)
+
+ def __repr__(self):
+ return self.__str__()
+
+ def max_extents(self):
+ extents = [c.max_extents() for c in self. children]
+ extents.append((self.x, self.y))
+ return np.max(extents, axis=0)
+
+
+def buchheim(tree):
+ dt = first_walk(DrawTree(tree))
+ min = second_walk(dt)
+ if min < 0:
+ third_walk(dt, -min)
+ return dt
+
+
+def third_walk(tree, n):
+ tree.x += n
+ for c in tree.children:
+ third_walk(c, n)
+
+
+def first_walk(v, distance=1.):
+ if len(v.children) == 0:
+ if v.lmost_sibling:
+ v.x = v.lbrother().x + distance
+ else:
+ v.x = 0.
+ else:
+ default_ancestor = v.children[0]
+ for w in v.children:
+ first_walk(w)
+ default_ancestor = apportion(w, default_ancestor, distance)
+ # print("finished v =", v.tree, "children")
+ execute_shifts(v)
+
+ midpoint = (v.children[0].x + v.children[-1].x) / 2
+
+ w = v.lbrother()
+ if w:
+ v.x = w.x + distance
+ v.mod = v.x - midpoint
+ else:
+ v.x = midpoint
+ return v
+
+
+def apportion(v, default_ancestor, distance):
+ w = v.lbrother()
+ if w is not None:
+ # in buchheim notation:
+ # i == inner; o == outer; r == right; l == left; r = +; l = -
+ vir = vor = v
+ vil = w
+ vol = v.lmost_sibling
+ sir = sor = v.mod
+ sil = vil.mod
+ sol = vol.mod
+ while vil.right() and vir.left():
+ vil = vil.right()
+ vir = vir.left()
+ vol = vol.left()
+ vor = vor.right()
+ vor.ancestor = v
+ shift = (vil.x + sil) - (vir.x + sir) + distance
+ if shift > 0:
+ move_subtree(ancestor(vil, v, default_ancestor), v, shift)
+ sir = sir + shift
+ sor = sor + shift
+ sil += vil.mod
+ sir += vir.mod
+ sol += vol.mod
+ sor += vor.mod
+ if vil.right() and not vor.right():
+ vor.thread = vil.right()
+ vor.mod += sil - sor
+ else:
+ if vir.left() and not vol.left():
+ vol.thread = vir.left()
+ vol.mod += sir - sol
+ default_ancestor = v
+ return default_ancestor
+
+
+def move_subtree(wl, wr, shift):
+ subtrees = wr.number - wl.number
+ # print(wl.tree, "is conflicted with", wr.tree, 'moving', subtrees,
+ # 'shift', shift)
+ # print wl, wr, wr.number, wl.number, shift, subtrees, shift/subtrees
+ wr.change -= shift / subtrees
+ wr.shift += shift
+ wl.change += shift / subtrees
+ wr.x += shift
+ wr.mod += shift
+
+
+def execute_shifts(v):
+ shift = change = 0
+ for w in v.children[::-1]:
+ # print("shift:", w, shift, w.change)
+ w.x += shift
+ w.mod += shift
+ change += w.change
+ shift += w.shift + change
+
+
+def ancestor(vil, v, default_ancestor):
+ # the relevant text is at the bottom of page 7 of
+ # "Improving Walker's Algorithm to Run in Linear Time" by Buchheim et al,
+ # (2002)
+ # http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.16.8757&rep=rep1&type=pdf
+ if vil.ancestor in v.parent.children:
+ return vil.ancestor
+ else:
+ return default_ancestor
+
+
+def second_walk(v, m=0, depth=0, min=None):
+ v.x += m
+ v.y = depth
+
+ if min is None or v.x < min:
+ min = v.x
+
+ for w in v.children:
+ min = second_walk(w, m + v.mod, depth + 1, min)
+
+ return min
+
+
+class Tree(object):
+ def __init__(self, label="", node_id=-1, *children):
+ self.label = label
+ self.node_id = node_id
+ if children:
+ self.children = children
+ else:
+ self.children = []
diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py
index ef13790e65b42..fe127d77302b6 100644
--- a/sklearn/tree/export.py
+++ b/sklearn/tree/export.py
@@ -10,6 +10,7 @@
# Trevor Stephens
# Li Li
# License: BSD 3 clause
+import warnings
from numbers import Integral
@@ -20,6 +21,7 @@
from . import _criterion
from . import _tree
+from ._reingold_tilford import buchheim, Tree
def _color_brew(n):
@@ -72,37 +74,30 @@ def __repr__(self):
SENTINEL = Sentinel()
-def export_graphviz(decision_tree, out_file=None, max_depth=None,
- feature_names=None, class_names=None, label='all',
- filled=False, leaves_parallel=False, impurity=True,
- node_ids=False, proportion=False, rotate=False,
- rounded=False, special_characters=False, precision=3):
- """Export a decision tree in DOT format.
-
- This function generates a GraphViz representation of the decision tree,
- which is then written into `out_file`. Once exported, graphical renderings
- can be generated using, for example::
-
- $ dot -Tps tree.dot -o tree.ps (PostScript format)
- $ dot -Tpng tree.dot -o tree.png (PNG format)
+def plot_tree(decision_tree, max_depth=None, feature_names=None,
+ class_names=None, label='all', filled=False,
+ impurity=True, node_ids=False,
+ proportion=False, rotate=False, rounded=False,
+ precision=3, ax=None, fontsize=None):
+ """Plot a decision tree.
The sample counts that are shown are weighted with any sample_weights that
might be present.
+ This function requires matplotlib, and works best with matplotlib >= 1.5.
+
+ The visualization is fit automatically to the size of the axis.
+ Use the ``figsize`` or ``dpi`` arguments of ``plt.figure`` to control
+ the size of the rendering.
Read more in the :ref:`User Guide <tree>`.
+ .. versionadded:: 0.21
+
Parameters
----------
decision_tree : decision tree regressor or classifier
The decision tree to be exported to GraphViz.
- out_file : file object or string, optional (default=None)
- Handle or name of the output file. If ``None``, the result is
- returned as a string.
-
- .. versionchanged:: 0.20
- Default of out_file changed from "tree.dot" to None.
-
max_depth : int, optional (default=None)
The maximum depth of the representation. If None, the tree is fully
generated.
@@ -125,9 +120,6 @@ def export_graphviz(decision_tree, out_file=None, max_depth=None,
classification, extremity of values for regression, or purity of node
for multi-output.
- leaves_parallel : bool, optional (default=False)
- When set to ``True``, draw all leaf nodes at the bottom of the tree.
-
impurity : bool, optional (default=True)
When set to ``True``, show the impurity at each node.
@@ -145,64 +137,113 @@ def export_graphviz(decision_tree, out_file=None, max_depth=None,
When set to ``True``, draw node boxes with rounded corners and use
Helvetica fonts instead of Times-Roman.
- special_characters : bool, optional (default=False)
- When set to ``False``, ignore special characters for PostScript
- compatibility.
-
precision : int, optional (default=3)
Number of digits of precision for floating point in the values of
impurity, threshold and value attributes of each node.
+ ax : matplotlib axis, optional (default=None)
+ Axes to plot to. If None, use current axis. Any previous content
+ is cleared.
+
+ fontsize : int, optional (default=None)
+ Size of text font. If None, determined automatically to fit figure.
+
Returns
-------
- dot_data : string
- String representation of the input tree in GraphViz dot format.
- Only returned if ``out_file`` is None.
-
- .. versionadded:: 0.18
+ annotations : list of artists
+ List containing the artists for the annotation boxes making up the
+ tree.
Examples
--------
>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
- >>> clf = tree.DecisionTreeClassifier()
+ >>> clf = tree.DecisionTreeClassifier(random_state=0)
>>> iris = load_iris()
>>> clf = clf.fit(iris.data, iris.target)
- >>> tree.export_graphviz(clf,
- ... out_file='tree.dot') # doctest: +SKIP
+ >>> tree.plot_tree(clf) # doctest: +SKIP
+ [Text(251.5,345.217,'X[3] <= 0.8...
"""
-
- def get_color(value):
+ exporter = _MPLTreeExporter(
+ max_depth=max_depth, feature_names=feature_names,
+ class_names=class_names, label=label, filled=filled,
+ impurity=impurity, node_ids=node_ids,
+ proportion=proportion, rotate=rotate, rounded=rounded,
+ precision=precision, fontsize=fontsize)
+ return exporter.export(decision_tree, ax=ax)
+
+
+class _BaseTreeExporter(object):
+ def __init__(self, max_depth=None, feature_names=None,
+ class_names=None, label='all', filled=False,
+ impurity=True, node_ids=False,
+ proportion=False, rotate=False, rounded=False,
+ precision=3, fontsize=None):
+ self.max_depth = max_depth
+ self.feature_names = feature_names
+ self.class_names = class_names
+ self.label = label
+ self.filled = filled
+ self.impurity = impurity
+ self.node_ids = node_ids
+ self.proportion = proportion
+ self.rotate = rotate
+ self.rounded = rounded
+ self.precision = precision
+ self.fontsize = fontsize
+
+ def get_color(self, value):
# Find the appropriate color & intensity for a node
- if colors['bounds'] is None:
+ if self.colors['bounds'] is None:
# Classification tree
- color = list(colors['rgb'][np.argmax(value)])
+ color = list(self.colors['rgb'][np.argmax(value)])
sorted_values = sorted(value, reverse=True)
if len(sorted_values) == 1:
alpha = 0
else:
- alpha = int(np.round(255 * (sorted_values[0] -
- sorted_values[1]) /
- (1 - sorted_values[1]), 0))
+ alpha = ((sorted_values[0] - sorted_values[1])
+ / (1 - sorted_values[1]))
else:
# Regression tree or multi-output
- color = list(colors['rgb'][0])
- alpha = int(np.round(255 * ((value - colors['bounds'][0]) /
- (colors['bounds'][1] -
- colors['bounds'][0])), 0))
-
- # Return html color code in #RRGGBBAA format
- color.append(alpha)
- hex_codes = [str(i) for i in range(10)]
- hex_codes.extend(['a', 'b', 'c', 'd', 'e', 'f'])
- color = [hex_codes[c // 16] + hex_codes[c % 16] for c in color]
-
- return '#' + ''.join(color)
+ color = list(self.colors['rgb'][0])
+ alpha = ((value - self.colors['bounds'][0]) /
+ (self.colors['bounds'][1] - self.colors['bounds'][0]))
+ # unpack numpy scalars
+ alpha = float(alpha)
+ # compute the color as alpha against white
+ color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color]
+ # Return html color code in #RRGGBB format
+ return '#%2x%2x%2x' % tuple(color)
+
+ def get_fill_color(self, tree, node_id):
+ # Fetch appropriate color for node
+ if 'rgb' not in self.colors:
+ # Initialize colors and bounds if required
+ self.colors['rgb'] = _color_brew(tree.n_classes[0])
+ if tree.n_outputs != 1:
+ # Find max and min impurities for multi-output
+ self.colors['bounds'] = (np.min(-tree.impurity),
+ np.max(-tree.impurity))
+ elif (tree.n_classes[0] == 1 and
+ len(np.unique(tree.value)) != 1):
+ # Find max and min values in leaf nodes for regression
+ self.colors['bounds'] = (np.min(tree.value),
+ np.max(tree.value))
+ if tree.n_outputs == 1:
+ node_val = (tree.value[node_id][0, :] /
+ tree.weighted_n_node_samples[node_id])
+ if tree.n_classes[0] == 1:
+ # Regression
+ node_val = tree.value[node_id][0, :]
+ else:
+ # If multi-output color node by impurity
+ node_val = -tree.impurity[node_id]
+ return self.get_color(node_val)
- def node_to_str(tree, node_id, criterion):
+ def node_to_str(self, tree, node_id, criterion):
# Generate the node content string
if tree.n_outputs == 1:
value = tree.value[node_id][0, :]
@@ -210,18 +251,13 @@ def node_to_str(tree, node_id, criterion):
value = tree.value[node_id]
# Should labels be shown?
- labels = (label == 'root' and node_id == 0) or label == 'all'
+ labels = (self.label == 'root' and node_id == 0) or self.label == 'all'
- # PostScript compatibility for special characters
- if special_characters:
- characters = ['&#35;', '<SUB>', '</SUB>', '&le;', '<br/>', '>']
- node_string = '<'
- else:
- characters = ['#', '[', ']', '<=', '\\n', '"']
- node_string = '"'
+ characters = self.characters
+ node_string = characters[-1]
# Write node ID
- if node_ids:
+ if self.node_ids:
if labels:
node_string += 'node '
node_string += characters[0] + str(node_id) + characters[4]
@@ -229,8 +265,8 @@ def node_to_str(tree, node_id, criterion):
# Write decision criteria
if tree.children_left[node_id] != _tree.TREE_LEAF:
# Always write node decision criteria, except for leaves
- if feature_names is not None:
- feature = feature_names[tree.feature[node_id]]
+ if self.feature_names is not None:
+ feature = self.feature_names[tree.feature[node_id]]
else:
feature = "X%s%s%s" % (characters[1],
tree.feature[node_id],
@@ -238,24 +274,24 @@ def node_to_str(tree, node_id, criterion):
node_string += '%s %s %s%s' % (feature,
characters[3],
round(tree.threshold[node_id],
- precision),
+ self.precision),
characters[4])
# Write impurity
- if impurity:
+ if self.impurity:
if isinstance(criterion, _criterion.FriedmanMSE):
criterion = "friedman_mse"
elif not isinstance(criterion, six.string_types):
criterion = "impurity"
if labels:
node_string += '%s = ' % criterion
- node_string += (str(round(tree.impurity[node_id], precision)) +
- characters[4])
+ node_string += (str(round(tree.impurity[node_id], self.precision))
+ + characters[4])
# Write node sample count
if labels:
node_string += 'samples = '
- if proportion:
+ if self.proportion:
percent = (100. * tree.n_node_samples[node_id] /
float(tree.n_node_samples[0]))
node_string += (str(round(percent, 1)) + '%' +
@@ -265,23 +301,23 @@ def node_to_str(tree, node_id, criterion):
characters[4])
# Write node class distribution / regression value
- if proportion and tree.n_classes[0] != 1:
+ if self.proportion and tree.n_classes[0] != 1:
# For classification this will show the proportion of samples
value = value / tree.weighted_n_node_samples[node_id]
if labels:
node_string += 'value = '
if tree.n_classes[0] == 1:
# Regression
- value_text = np.around(value, precision)
- elif proportion:
+ value_text = np.around(value, self.precision)
+ elif self.proportion:
# Classification
- value_text = np.around(value, precision)
+ value_text = np.around(value, self.precision)
elif np.all(np.equal(np.mod(value, 1), 0)):
# Classification without floating-point weights
value_text = value.astype(int)
else:
# Classification with floating-point weights
- value_text = np.around(value, precision)
+ value_text = np.around(value, self.precision)
# Strip whitespace
value_text = str(value_text.astype('S32')).replace("b'", "'")
value_text = value_text.replace("' '", ", ").replace("'", "")
@@ -291,14 +327,14 @@ def node_to_str(tree, node_id, criterion):
node_string += value_text + characters[4]
# Write node majority class
- if (class_names is not None and
+ if (self.class_names is not None and
tree.n_classes[0] != 1 and
tree.n_outputs == 1):
# Only done for single-output classification trees
if labels:
node_string += 'class = '
- if class_names is not True:
- class_name = class_names[np.argmax(value)]
+ if self.class_names is not True:
+ class_name = self.class_names[np.argmax(value)]
else:
class_name = "y%s%s%s" % (characters[1],
np.argmax(value),
@@ -306,14 +342,109 @@ def node_to_str(tree, node_id, criterion):
node_string += class_name
# Clean up any trailing newlines
- if node_string[-2:] == '\\n':
- node_string = node_string[:-2]
- if node_string[-5:] == '<br/>':
- node_string = node_string[:-5]
+ if node_string.endswith(characters[4]):
+ node_string = node_string[:-len(characters[4])]
return node_string + characters[5]
- def recurse(tree, node_id, criterion, parent=None, depth=0):
+
+class _DOTTreeExporter(_BaseTreeExporter):
+ def __init__(self, out_file=SENTINEL, max_depth=None,
+ feature_names=None, class_names=None, label='all',
+ filled=False, leaves_parallel=False, impurity=True,
+ node_ids=False, proportion=False, rotate=False, rounded=False,
+ special_characters=False, precision=3):
+
+ super(_DOTTreeExporter, self).__init__(
+ max_depth=max_depth, feature_names=feature_names,
+ class_names=class_names, label=label, filled=filled,
+ impurity=impurity,
+ node_ids=node_ids, proportion=proportion, rotate=rotate,
+ rounded=rounded,
+ precision=precision)
+ self.leaves_parallel = leaves_parallel
+ self.out_file = out_file
+ self.special_characters = special_characters
+
+ # PostScript compatibility for special characters
+ if special_characters:
+ self.characters = ['&#35;', '<SUB>', '</SUB>', '&le;', '<br/>',
+ '>', '<']
+ else:
+ self.characters = ['#', '[', ']', '<=', '\\n', '"', '"']
+
+ # validate
+ if isinstance(precision, Integral):
+ if precision < 0:
+ raise ValueError("'precision' should be greater or equal to 0."
+ " Got {} instead.".format(precision))
+ else:
+ raise ValueError("'precision' should be an integer. Got {}"
+ " instead.".format(type(precision)))
+
+ # The depth of each node for plotting with 'leaf' option
+ self.ranks = {'leaves': []}
+ # The colors to render each node with
+ self.colors = {'bounds': None}
+
+ def export(self, decision_tree):
+ # Check length of feature_names before getting into the tree node
+ # Raise error if length of feature_names does not match
+ # n_features_ in the decision_tree
+ if self.feature_names is not None:
+ if len(self.feature_names) != decision_tree.n_features_:
+ raise ValueError("Length of feature_names, %d "
+ "does not match number of features, %d"
+ % (len(self.feature_names),
+ decision_tree.n_features_))
+ # each part writes to out_file
+ self.head()
+ # Now recurse the tree and add node & edge attributes
+ if isinstance(decision_tree, _tree.Tree):
+ self.recurse(decision_tree, 0, criterion="impurity")
+ else:
+ self.recurse(decision_tree.tree_, 0,
+ criterion=decision_tree.criterion)
+
+ self.tail()
+
+ def tail(self):
+ # If required, draw leaf nodes at same depth as each other
+ if self.leaves_parallel:
+ for rank in sorted(self.ranks):
+ self.out_file.write(
+ "{rank=same ; " +
+ "; ".join(r for r in self.ranks[rank]) + "} ;\n")
+ self.out_file.write("}")
+
+ def head(self):
+ self.out_file.write('digraph Tree {\n')
+
+ # Specify node aesthetics
+ self.out_file.write('node [shape=box')
+ rounded_filled = []
+ if self.filled:
+ rounded_filled.append('filled')
+ if self.rounded:
+ rounded_filled.append('rounded')
+ if len(rounded_filled) > 0:
+ self.out_file.write(
+ ', style="%s", color="black"'
+ % ", ".join(rounded_filled))
+ if self.rounded:
+ self.out_file.write(', fontname=helvetica')
+ self.out_file.write('] ;\n')
+
+ # Specify graph & edge aesthetics
+ if self.leaves_parallel:
+ self.out_file.write(
+ 'graph [ranksep=equally, splines=polyline] ;\n')
+ if self.rounded:
+ self.out_file.write('edge [fontname=helvetica] ;\n')
+ if self.rotate:
+ self.out_file.write('rankdir=LR ;\n')
+
+ def recurse(self, tree, node_id, criterion, parent=None, depth=0):
if node_id == _tree.TREE_LEAF:
raise ValueError("Invalid node_id %s" % _tree.TREE_LEAF)
@@ -321,93 +452,75 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
right_child = tree.children_right[node_id]
# Add node with description
- if max_depth is None or depth <= max_depth:
+ if self.max_depth is None or depth <= self.max_depth:
# Collect ranks for 'leaf' option in plot_options
if left_child == _tree.TREE_LEAF:
- ranks['leaves'].append(str(node_id))
- elif str(depth) not in ranks:
- ranks[str(depth)] = [str(node_id)]
+ self.ranks['leaves'].append(str(node_id))
+ elif str(depth) not in self.ranks:
+ self.ranks[str(depth)] = [str(node_id)]
else:
- ranks[str(depth)].append(str(node_id))
-
- out_file.write('%d [label=%s'
- % (node_id,
- node_to_str(tree, node_id, criterion)))
-
- if filled:
- # Fetch appropriate color for node
- if 'rgb' not in colors:
- # Initialize colors and bounds if required
- colors['rgb'] = _color_brew(tree.n_classes[0])
- if tree.n_outputs != 1:
- # Find max and min impurities for multi-output
- colors['bounds'] = (np.min(-tree.impurity),
- np.max(-tree.impurity))
- elif (tree.n_classes[0] == 1 and
- len(np.unique(tree.value)) != 1):
- # Find max and min values in leaf nodes for regression
- colors['bounds'] = (np.min(tree.value),
- np.max(tree.value))
- if tree.n_outputs == 1:
- node_val = (tree.value[node_id][0, :] /
- tree.weighted_n_node_samples[node_id])
- if tree.n_classes[0] == 1:
- # Regression
- node_val = tree.value[node_id][0, :]
- else:
- # If multi-output color node by impurity
- node_val = -tree.impurity[node_id]
- out_file.write(', fillcolor="%s"' % get_color(node_val))
- out_file.write('] ;\n')
+ self.ranks[str(depth)].append(str(node_id))
+
+ self.out_file.write(
+ '%d [label=%s' % (node_id, self.node_to_str(tree, node_id,
+ criterion)))
+
+ if self.filled:
+ self.out_file.write(', fillcolor="%s"'
+ % self.get_fill_color(tree, node_id))
+ self.out_file.write('] ;\n')
if parent is not None:
# Add edge to parent
- out_file.write('%d -> %d' % (parent, node_id))
+ self.out_file.write('%d -> %d' % (parent, node_id))
if parent == 0:
# Draw True/False labels if parent is root node
- angles = np.array([45, -45]) * ((rotate - .5) * -2)
- out_file.write(' [labeldistance=2.5, labelangle=')
+ angles = np.array([45, -45]) * ((self.rotate - .5) * -2)
+ self.out_file.write(' [labeldistance=2.5, labelangle=')
if node_id == 1:
- out_file.write('%d, headlabel="True"]' % angles[0])
+ self.out_file.write('%d, headlabel="True"]' %
+ angles[0])
else:
- out_file.write('%d, headlabel="False"]' % angles[1])
- out_file.write(' ;\n')
+ self.out_file.write('%d, headlabel="False"]' %
+ angles[1])
+ self.out_file.write(' ;\n')
if left_child != _tree.TREE_LEAF:
- recurse(tree, left_child, criterion=criterion, parent=node_id,
- depth=depth + 1)
- recurse(tree, right_child, criterion=criterion, parent=node_id,
- depth=depth + 1)
+ self.recurse(tree, left_child, criterion=criterion,
+ parent=node_id, depth=depth + 1)
+ self.recurse(tree, right_child, criterion=criterion,
+ parent=node_id, depth=depth + 1)
else:
- ranks['leaves'].append(str(node_id))
+ self.ranks['leaves'].append(str(node_id))
- out_file.write('%d [label="(...)"' % node_id)
- if filled:
+ self.out_file.write('%d [label="(...)"' % node_id)
+ if self.filled:
# color cropped nodes grey
- out_file.write(', fillcolor="#C0C0C0"')
- out_file.write('] ;\n' % node_id)
+ self.out_file.write(', fillcolor="#C0C0C0"')
+ self.out_file.write('] ;\n' % node_id)
if parent is not None:
# Add edge to parent
- out_file.write('%d -> %d ;\n' % (parent, node_id))
+ self.out_file.write('%d -> %d ;\n' % (parent, node_id))
- check_is_fitted(decision_tree, 'tree_')
- own_file = False
- return_string = False
- try:
- if isinstance(out_file, six.string_types):
- if six.PY3:
- out_file = open(out_file, "w", encoding="utf-8")
- else:
- out_file = open(out_file, "wb")
- own_file = True
- if out_file is None:
- return_string = True
- out_file = six.StringIO()
+class _MPLTreeExporter(_BaseTreeExporter):
+ def __init__(self, max_depth=None, feature_names=None,
+ class_names=None, label='all', filled=False,
+ impurity=True, node_ids=False,
+ proportion=False, rotate=False, rounded=False,
+ precision=3, fontsize=None):
+
+ super(_MPLTreeExporter, self).__init__(
+ max_depth=max_depth, feature_names=feature_names,
+ class_names=class_names, label=label, filled=filled,
+ impurity=impurity, node_ids=node_ids, proportion=proportion,
+ rotate=rotate, rounded=rounded, precision=precision)
+ self.fontsize = fontsize
+ # validate
if isinstance(precision, Integral):
if precision < 0:
raise ValueError("'precision' should be greater or equal to 0."
@@ -416,57 +529,254 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
raise ValueError("'precision' should be an integer. Got {}"
" instead.".format(type(precision)))
- # Check length of feature_names before getting into the tree node
- # Raise error if length of feature_names does not match
- # n_features_ in the decision_tree
- if feature_names is not None:
- if len(feature_names) != decision_tree.n_features_:
- raise ValueError("Length of feature_names, %d "
- "does not match number of features, %d"
- % (len(feature_names),
- decision_tree.n_features_))
-
# The depth of each node for plotting with 'leaf' option
- ranks = {'leaves': []}
+ self.ranks = {'leaves': []}
# The colors to render each node with
- colors = {'bounds': None}
+ self.colors = {'bounds': None}
- out_file.write('digraph Tree {\n')
+ self.characters = ['#', '[', ']', '<=', '\n', '', '']
- # Specify node aesthetics
- out_file.write('node [shape=box')
- rounded_filled = []
- if filled:
- rounded_filled.append('filled')
- if rounded:
- rounded_filled.append('rounded')
- if len(rounded_filled) > 0:
- out_file.write(', style="%s", color="black"'
- % ", ".join(rounded_filled))
- if rounded:
- out_file.write(', fontname=helvetica')
- out_file.write('] ;\n')
+ self.bbox_args = dict(fc='w')
+ if self.rounded:
+ self.bbox_args['boxstyle'] = "round"
+ else:
+ # matplotlib <1.5 requires explicit boxstyle
+ self.bbox_args['boxstyle'] = "square"
+
+ self.arrow_args = dict(arrowstyle="<-")
+
+    def _make_tree(self, node_id, et, depth=0):
+        # Recursively mirror the fitted ``_tree.Tree`` ``et`` as a
+        # ``_reingold_tilford.Tree``. Children are only expanded for split
+        # nodes and while ``depth <= self.max_depth`` (when it is set).
+        # NOTE(review): the criterion is fixed to 'entropy' here, so the
+        # label does not follow the estimator's own criterion -- confirm.
+        label = self.node_to_str(et, node_id, criterion='entropy')
+        expand = (et.children_left[node_id] != _tree.TREE_LEAF
+                  and (self.max_depth is None or depth <= self.max_depth))
+        if not expand:
+            return Tree(label, node_id)
+        child_ids = (et.children_left[node_id], et.children_right[node_id])
+        subtrees = [self._make_tree(child_id, et, depth=depth + 1)
+                    for child_id in child_ids]
+        return Tree(label, node_id, *subtrees)
+
+ def export(self, decision_tree, ax=None):
+ # Draw ``decision_tree`` as matplotlib annotations on ``ax`` and return
+ # the list of Annotation artists found on the axes afterwards.
+ # When ``ax`` is None the current axes (plt.gca()) are used. The axes
+ # are cleared and their frame is switched off before drawing.
+ # matplotlib is imported here, inside the method, not at module level.
+ import matplotlib.pyplot as plt
+ from matplotlib.text import Annotation
+ if ax is None:
+ ax = plt.gca()
+ ax.clear()
+ ax.set_axis_off()
+ # build the intermediate Tree starting at node 0, then lay it out
+ my_tree = self._make_tree(0, decision_tree.tree_)
+ draw_tree = buchheim(my_tree)
+
+ # important to make sure we're still
+ # inside the axis after drawing the box
+ # this makes sense because the width of a box
+ # is about the same as the distance between boxes
+ # NOTE(review): the ``+ 1`` presumably applies element-wise to an
+ # array-like (max_x, max_y) -- confirm against max_extents().
+ max_x, max_y = draw_tree.max_extents() + 1
+ ax_width = ax.get_window_extent().width
+ ax_height = ax.get_window_extent().height
+
+ # pixels available per layout unit in each direction
+ scale_x = ax_width / max_x
+ scale_y = ax_height / max_y
+
+ self.recurse(draw_tree, decision_tree.tree_, ax,
+ scale_x, scale_y, ax_height)
+
+ # collect the annotations that recurse() added to the axes
+ anns = [ann for ann in ax.get_children()
+ if isinstance(ann, Annotation)]
+
+ # update sizes of all bboxes
+ renderer = ax.figure.canvas.get_renderer()
+
+ for ann in anns:
+ ann.update_bbox_position_size(renderer)
+
+ if self.fontsize is None:
+ # no explicit fontsize: rescale the text of every annotation by the
+ # smaller of scale_x / widest box and scale_y / tallest box, so the
+ # largest box fits into one layout cell
+ try:
+ extents = [ann.get_bbox_patch().get_window_extent()
+ for ann in anns]
+ max_width = max([extent.width for extent in extents])
+ max_height = max([extent.height for extent in extents])
+ # width should be around scale_x in axis coordinates
+ size = anns[0].get_fontsize() * min(scale_x / max_width,
+ scale_y / max_height)
+ for ann in anns:
+ ann.set_fontsize(size)
+ except AttributeError:
+ # matplotlib < 1.5
+ # NOTE(review): relies on ``warnings`` being imported at module
+ # level, which is not visible in this hunk -- confirm.
+ warnings.warn("Automatic scaling of tree plots requires "
+ "matplotlib 1.5 or higher. Please specify "
+ "fontsize.")
+
+ return anns
+
+ def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0):
+ # Annotate ``ax`` with one box for the layout node ``node`` and then
+ # visit its children. ``tree`` is only passed on to get_fill_color.
+ # ``scale_x`` / ``scale_y`` convert layout units to axes pixels and
+ # ``height`` is used to flip y so that smaller node.y is drawn higher.
+ # need to copy bbox args because matplotlib <1.5 modifies them
+ # deeper nodes get a lower zorder (100 - 10 * depth)
+ kwargs = dict(bbox=self.bbox_args.copy(), ha='center', va='center',
+ zorder=100 - 10 * depth, xycoords='axes pixels')
+
+ if self.fontsize is not None:
+ kwargs['fontsize'] = self.fontsize
+
+ # offset things by .5 to center them in plot
+ xy = ((node.x + .5) * scale_x, height - (node.y + .5) * scale_y)
+
+ if self.max_depth is None or depth <= self.max_depth:
+ if self.filled:
+ kwargs['bbox']['fc'] = self.get_fill_color(tree,
+ node.tree.node_id)
+ if node.parent is None:
+ # root
+ ax.annotate(node.tree.label, xy, **kwargs)
+ else:
+ # annotate(text, xy, xytext): the parent position is the annotated
+ # point and this node's position is where the text box is placed,
+ # so the arrow connects the two
+ xy_parent = ((node.parent.x + .5) * scale_x,
+ height - (node.parent.y + .5) * scale_y)
+ kwargs["arrowprops"] = self.arrow_args
+ ax.annotate(node.tree.label, xy_parent, xy, **kwargs)
+ for child in node.children:
+ self.recurse(child, tree, ax, scale_x, scale_y, height,
+ depth=depth + 1)
- # Specify graph & edge aesthetics
- if leaves_parallel:
- out_file.write('graph [ranksep=equally, splines=polyline] ;\n')
- if rounded:
- out_file.write('edge [fontname=helvetica] ;\n')
- if rotate:
- out_file.write('rankdir=LR ;\n')
+ else:
+ # node lies below max_depth: draw a grey "(...)" placeholder linked
+ # to its parent and do not visit its children
+ # NOTE(review): indentation is lost in this copy; this else is read
+ # as pairing with the max_depth check above -- confirm.
+ xy_parent = ((node.parent.x + .5) * scale_x,
+ height - (node.parent.y + .5) * scale_y)
+ kwargs["arrowprops"] = self.arrow_args
+ kwargs['bbox']['fc'] = 'grey'
+ ax.annotate("\n (...) \n", xy_parent, xy, **kwargs)
- # Now recurse the tree and add node & edge attributes
- recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion)
- # If required, draw leaf nodes at same depth as each other
- if leaves_parallel:
- for rank in sorted(ranks):
- out_file.write("{rank=same ; " +
- "; ".join(r for r in ranks[rank]) + "} ;\n")
- out_file.write("}")
+def export_graphviz(decision_tree, out_file=None, max_depth=None,
+ feature_names=None, class_names=None, label='all',
+ filled=False, leaves_parallel=False, impurity=True,
+ node_ids=False, proportion=False, rotate=False,
+ rounded=False, special_characters=False, precision=3):
+ """Export a decision tree in DOT format.
+
+ This function generates a GraphViz representation of the decision tree,
+ which is then written into `out_file`. Once exported, graphical renderings
+ can be generated using, for example::
+
+ $ dot -Tps tree.dot -o tree.ps (PostScript format)
+ $ dot -Tpng tree.dot -o tree.png (PNG format)
+
+ The sample counts that are shown are weighted with any sample_weights that
+ might be present.
+
+ Read more in the :ref:`User Guide <tree>`.
+
+ Parameters
+ ----------
+ decision_tree : decision tree classifier
+ The decision tree to be exported to GraphViz.
+
+ out_file : file object or string, optional (default=None)
+ Handle or name of the output file. If ``None``, the result is
+ returned as a string.
+
+ .. versionchanged:: 0.20
+ Default of out_file changed from "tree.dot" to None.
+
+ max_depth : int, optional (default=None)
+ The maximum depth of the representation. If None, the tree is fully
+ generated.
+
+ feature_names : list of strings, optional (default=None)
+ Names of each of the features.
+
+ class_names : list of strings, bool or None, optional (default=None)
+ Names of each of the target classes in ascending numerical order.
+ Only relevant for classification and not supported for multi-output.
+ If ``True``, shows a symbolic representation of the class name.
+
+ label : {'all', 'root', 'none'}, optional (default='all')
+ Whether to show informative labels for impurity, etc.
+ Options include 'all' to show at every node, 'root' to show only at
+ the top root node, or 'none' to not show at any node.
+
+ filled : bool, optional (default=False)
+ When set to ``True``, paint nodes to indicate majority class for
+ classification, extremity of values for regression, or purity of node
+ for multi-output.
+
+ leaves_parallel : bool, optional (default=False)
+ When set to ``True``, draw all leaf nodes at the bottom of the tree.
+
+ impurity : bool, optional (default=True)
+ When set to ``True``, show the impurity at each node.
+
+ node_ids : bool, optional (default=False)
+ When set to ``True``, show the ID number on each node.
+
+ proportion : bool, optional (default=False)
+ When set to ``True``, change the display of 'values' and/or 'samples'
+ to be proportions and percentages respectively.
+
+ rotate : bool, optional (default=False)
+ When set to ``True``, orient tree left to right rather than top-down.
+
+ rounded : bool, optional (default=False)
+ When set to ``True``, draw node boxes with rounded corners and use
+ Helvetica fonts instead of Times-Roman.
+
+ special_characters : bool, optional (default=False)
+ When set to ``False``, ignore special characters for PostScript
+ compatibility.
+
+ precision : int, optional (default=3)
+ Number of digits of precision for floating point in the values of
+ impurity, threshold and value attributes of each node.
+
+ Returns
+ -------
+ dot_data : string
+ String representation of the input tree in GraphViz dot format.
+ Only returned if ``out_file`` is None.
+
+ .. versionadded:: 0.18
+
+ Examples
+ --------
+ >>> from sklearn.datasets import load_iris
+ >>> from sklearn import tree
+
+ >>> clf = tree.DecisionTreeClassifier()
+ >>> iris = load_iris()
+
+ >>> clf = clf.fit(iris.data, iris.target)
+ >>> tree.export_graphviz(clf) # doctest: +ELLIPSIS
+ 'digraph Tree {...
+ """
+
+ check_is_fitted(decision_tree, 'tree_')
+ own_file = False
+ return_string = False
+ try:
+ if isinstance(out_file, six.string_types):
+ if six.PY3:
+ out_file = open(out_file, "w", encoding="utf-8")
+ else:
+ out_file = open(out_file, "wb")
+ own_file = True
+
+ if out_file is None:
+ return_string = True
+ out_file = six.StringIO()
+
+ exporter = _DOTTreeExporter(
+ out_file=out_file, max_depth=max_depth,
+ feature_names=feature_names, class_names=class_names, label=label,
+ filled=filled, leaves_parallel=leaves_parallel, impurity=impurity,
+ node_ids=node_ids, proportion=proportion, rotate=rotate,
+ rounded=rounded, special_characters=special_characters,
+ precision=precision)
+ exporter.export(decision_tree)
if return_string:
- return out_file.getvalue()
+ return exporter.out_file.getvalue()
finally:
if own_file:
diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py
index c43e0a4f32392..2471914fa44ce 100644
--- a/sklearn/tree/tests/test_export.py
+++ b/sklearn/tree/tests/test_export.py
@@ -1,6 +1,7 @@
"""
Testing for export functions of decision trees (sklearn.tree.export).
"""
+import pytest
from re import finditer, search
@@ -9,7 +10,7 @@
from sklearn.base import is_classifier
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import GradientBoostingClassifier
-from sklearn.tree import export_graphviz
+from sklearn.tree import export_graphviz, plot_tree
from sklearn.externals.six import StringIO
from sklearn.utils.testing import (assert_in, assert_equal, assert_raises,
assert_less_equal, assert_raises_regex,
@@ -92,13 +93,13 @@ def test_graphviz_toy():
'fontname=helvetica] ;\n' \
'edge [fontname=helvetica] ;\n' \
'0 [label=<X<SUB>0</SUB> &le; 0.0<br/>samples = 100.0%<br/>' \
- 'value = [0.5, 0.5]>, fillcolor="#e5813900"] ;\n' \
+ 'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n' \
'1 [label=<samples = 50.0%<br/>value = [1.0, 0.0]>, ' \
- 'fillcolor="#e58139ff"] ;\n' \
+ 'fillcolor="#e58139"] ;\n' \
'0 -> 1 [labeldistance=2.5, labelangle=45, ' \
'headlabel="True"] ;\n' \
'2 [label=<samples = 50.0%<br/>value = [0.0, 1.0]>, ' \
- 'fillcolor="#399de5ff"] ;\n' \
+ 'fillcolor="#399de5"] ;\n' \
'0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
'headlabel="False"] ;\n' \
'}'
@@ -126,7 +127,7 @@ def test_graphviz_toy():
contents2 = 'digraph Tree {\n' \
'node [shape=box, style="filled", color="black"] ;\n' \
'0 [label="node #0\\nX[0] <= 0.0\\ngini = 0.5\\n' \
- 'samples = 6\\nvalue = [3, 3]", fillcolor="#e5813900"] ;\n' \
+ 'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n' \
'1 [label="(...)", fillcolor="#C0C0C0"] ;\n' \
'0 -> 1 ;\n' \
'2 [label="(...)", fillcolor="#C0C0C0"] ;\n' \
@@ -148,21 +149,21 @@ def test_graphviz_toy():
'node [shape=box, style="filled", color="black"] ;\n' \
'0 [label="X[0] <= 0.0\\nsamples = 6\\n' \
'value = [[3.0, 1.5, 0.0]\\n' \
- '[3.0, 1.0, 0.5]]", fillcolor="#e5813900"] ;\n' \
+ '[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n' \
'1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n' \
- '[3, 0, 0]]", fillcolor="#e58139ff"] ;\n' \
+ '[3, 0, 0]]", fillcolor="#e58139"] ;\n' \
'0 -> 1 [labeldistance=2.5, labelangle=45, ' \
'headlabel="True"] ;\n' \
'2 [label="X[0] <= 1.5\\nsamples = 3\\n' \
'value = [[0.0, 1.5, 0.0]\\n' \
- '[0.0, 1.0, 0.5]]", fillcolor="#e5813986"] ;\n' \
+ '[0.0, 1.0, 0.5]]", fillcolor="#f1bd97"] ;\n' \
'0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
'headlabel="False"] ;\n' \
'3 [label="samples = 2\\nvalue = [[0, 1, 0]\\n' \
- '[0, 1, 0]]", fillcolor="#e58139ff"] ;\n' \
+ '[0, 1, 0]]", fillcolor="#e58139"] ;\n' \
'2 -> 3 ;\n' \
'4 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n' \
- '[0.0, 0.0, 0.5]]", fillcolor="#e58139ff"] ;\n' \
+ '[0.0, 0.0, 0.5]]", fillcolor="#e58139"] ;\n' \
'2 -> 4 ;\n' \
'}'
@@ -184,13 +185,13 @@ def test_graphviz_toy():
'edge [fontname=helvetica] ;\n' \
'rankdir=LR ;\n' \
'0 [label="X[0] <= 0.0\\nmse = 1.0\\nsamples = 6\\n' \
- 'value = 0.0", fillcolor="#e5813980"] ;\n' \
+ 'value = 0.0", fillcolor="#f2c09c"] ;\n' \
'1 [label="mse = 0.0\\nsamples = 3\\nvalue = -1.0", ' \
- 'fillcolor="#e5813900"] ;\n' \
+ 'fillcolor="#ffffff"] ;\n' \
'0 -> 1 [labeldistance=2.5, labelangle=-45, ' \
'headlabel="True"] ;\n' \
'2 [label="mse = 0.0\\nsamples = 3\\nvalue = 1.0", ' \
- 'fillcolor="#e58139ff"] ;\n' \
+ 'fillcolor="#e58139"] ;\n' \
'0 -> 2 [labeldistance=2.5, labelangle=45, ' \
'headlabel="False"] ;\n' \
'{rank=same ; 0} ;\n' \
@@ -207,7 +208,7 @@ def test_graphviz_toy():
contents2 = 'digraph Tree {\n' \
'node [shape=box, style="filled", color="black"] ;\n' \
'0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", ' \
- 'fillcolor="#e5813900"] ;\n' \
+ 'fillcolor="#ffffff"] ;\n' \
'}'
@@ -308,3 +309,23 @@ def test_precision():
for finding in finditer(r"<= \d+\.\d+", dot_data):
assert_equal(len(search(r"\.\d+", finding.group()).group()),
precision + 1)
+
+
+def test_plot_tree():
+ # Smoke test for plot_tree; skipped when matplotlib cannot be imported.
+ pytest.importorskip("matplotlib.pyplot")
+ # Fit a small classifier (criterion="gini") on X, y, which are defined
+ # outside this function.
+ clf = DecisionTreeClassifier(max_depth=3,
+ min_samples_split=2,
+ criterion="gini",
+ random_state=2)
+ clf.fit(X, y)
+
+ # Plot the tree and check the text of the returned node annotations:
+ # one root and two leaves are expected.
+ feature_names = ['first feat', 'sepal_width']
+ nodes = plot_tree(clf, feature_names=feature_names)
+ assert len(nodes) == 3
+ # NOTE(review): the expected labels read "entropy" although the
+ # classifier was built with criterion="gini"; _make_tree passes
+ # criterion='entropy' unconditionally -- confirm this is intended.
+ assert nodes[0].get_text() == ("first feat <= 0.0\nentropy = 0.5\n"
+ "samples = 6\nvalue = [3, 3]")
+ assert nodes[1].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [3, 0]"
+ assert nodes[2].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [0, 3]"
diff --git a/sklearn/tree/tests/test_reingold_tilford.py b/sklearn/tree/tests/test_reingold_tilford.py
new file mode 100644
index 0000000000000..4cb27ce6effb9
--- /dev/null
+++ b/sklearn/tree/tests/test_reingold_tilford.py
@@ -0,0 +1,54 @@
+import numpy as np
+import pytest
+from sklearn.tree._reingold_tilford import buchheim, Tree
+
+# Fixtures for test_buchheim. Each Tree is built from an empty label, a
+# node id and its child trees as trailing arguments.
+# simple_tree: a root (id 0) with two leaves, 3 nodes in total.
+simple_tree = Tree("", 0,
+ Tree("", 1),
+ Tree("", 2))
+
+# bigger_tree: 9 nodes (ids 0-8). Node 1 has children 3 and 4, node 4 has
+# children 7 and 8, node 2 has children 5 and 6, so the left branch is one
+# level deeper than the right branch.
+bigger_tree = Tree("", 0,
+ Tree("", 1,
+ Tree("", 3),
+ Tree("", 4,
+ Tree("", 7),
+ Tree("", 8)
+ ),
+ ),
+ Tree("", 2,
+ Tree("", 5),
+ Tree("", 6)
+ )
+ )
+
+
+@pytest.mark.parametrize("tree, n_nodes", [(simple_tree, 3), (bigger_tree, 9)])
+def test_buchheim(tree, n_nodes):
+    # Check the layout invariants: y grows by one per level, parents are
+    # centered above their two children, and x is unique within a level.
+    def walk_tree(draw_tree):
+        res = [(draw_tree.x, draw_tree.y)]
+        for child in draw_tree.children:
+            # parents higher than children:
+            assert child.y == draw_tree.y + 1
+            res.extend(walk_tree(child))
+        if len(draw_tree.children):
+            # these trees are always binary
+            # parents are centered above children
+            assert draw_tree.x == (draw_tree.children[0].x
+                                   + draw_tree.children[1].x) / 2
+        return res
+
+    layout = buchheim(tree)
+    coordinates = walk_tree(layout)
+    assert len(coordinates) == n_nodes
+    # test that x values are unique per depth / level
+    # we could also do it quicker using defaultdicts..
+    depth = 0
+    while True:
+        # each entry of coordinates is an (x, y) pair; y is the depth
+        x_at_this_depth = [x for x, y in coordinates if y == depth]
+        if not x_at_this_depth:
+            # reached all leafs
+            break
+        assert len(np.unique(x_at_this_depth)) == len(x_at_this_depth)
+        depth += 1