Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[MRG] Matplotlib tree plotting #9251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 73 commits into from
Oct 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
1881ece
add reingold tillford tree layout algorithm
amueller Jun 29, 2017
bd7d022
add first silly implementation of matplotlib based plotting for trees
amueller Jun 29, 2017
287c1d2
object oriented design for export_graphviz so it can be extended
amueller Jun 29, 2017
0d5e3e2
add class for mlp export
amueller Jun 29, 2017
8f52d87
add colors
amueller Jun 29, 2017
4a5fe67
separately scale x and y, add arrowheads, fix strings
amueller Jun 29, 2017
ddb6c16
implement max_depth
amueller Jun 29, 2017
fed2d1d
don't use alpha for coloring because it makes boxes transparent
amueller Jun 29, 2017
5145ed2
remove unused variables
amueller Jun 29, 2017
8663ad7
vertical center of boxes
amueller Jun 29, 2017
d750deb
fix/simplify newline trimming
amueller Jun 29, 2017
d3c17ea
somewhere in the middle of stuff
amueller Jun 30, 2017
823ce1f
remove "find_longest_child" for now, fix tests
amueller Jul 6, 2017
0229d5d
make scalex and scaley internal, and ax local.
amueller Jul 6, 2017
a2df69e
add some margin to the max bbox width
amueller Jul 6, 2017
5212f59
add _BaseTreeExporter baseclass
amueller Jul 7, 2017
60c0b73
add docstring to plot_tree
amueller Jul 7, 2017
3b4a730
use data coordinates so we can put the plot in a subplot, remove some…
amueller Jul 7, 2017
a30f634
remove scalex, scaley, add automatic font size
amueller Jul 10, 2017
27a29ac
use rendered stuff for setting limits (well nearly there)
amueller Jul 10, 2017
c2e6d31
Merge branch 'master' into matplotlib_tree_plotting
amueller Jul 12, 2017
538d257
import plot_tree into tree module
amueller Jul 13, 2017
c6ecbb2
set limits before font size adjustment?
amueller Jul 13, 2017
fc7bdbe
add tree plotting via matplotlib to iris example and to docs
amueller Jul 13, 2017
9d672ab
pep8 fix
amueller Jul 13, 2017
1c8b8d6
skip doctest on plot_tree because matplotlib is not installed on all …
amueller Jul 13, 2017
474c557
redo everything in axis pixel coordinates
amueller Jul 14, 2017
4c97f37
fix max-depth
amueller Jul 14, 2017
b31e7ec
consider height in fontsize computation
amueller Jul 14, 2017
9f36648
fix error when max_depth is None
amueller Jul 15, 2017
697aede
Merge branch 'master' into matplotlib_tree_plotting
amueller Aug 29, 2017
d0b2c95
Merge branch 'master' into matplotlib_tree_plotting
amueller Nov 21, 2017
bac2c51
Merge branch 'matplotlib_tree_plotting' of github.com:amueller/scikit…
amueller Nov 21, 2017
752135e
add docstring for tree plotting fontsize
amueller Nov 21, 2017
9d7a3fa
Merge branch 'master' into matplotlib_tree_plotting
amueller Jun 28, 2018
fe92f74
starting on jnothman's review
amueller Jun 28, 2018
83a4a2e
renaming fixes
amueller Jun 28, 2018
a1ba414
whatsnew for tree plotting
amueller Jun 28, 2018
df6620d
clear axes prior to doing anything.
amueller Jun 28, 2018
02aeeab
fix doctests
amueller Jun 28, 2018
072a66b
skip matplotlib doctest
amueller Jun 28, 2018
88cc060
trying to debug circle failure
amueller Jun 28, 2018
9eb0549
trying to show full traceback
amueller Jun 29, 2018
59714c0
Merge branch 'master' into matplotlib_tree_plotting
amueller Jun 29, 2018
61dd70a
more print debugging
amueller Jun 29, 2018
5d84743
remove debugging crud
amueller Jun 29, 2018
cf4a620
hack around matplotlib <1.5 issues
amueller Jun 29, 2018
db61249
copy bbox args because old matplotlib is weird.
amueller Jun 29, 2018
5beedf2
pep8 fixes
amueller Jun 29, 2018
64d47ac
add explicit boxstyle
amueller Jun 29, 2018
cbfce1c
more pep8
amueller Jun 29, 2018
f92aca1
even more pep8
amueller Jun 29, 2018
d6cc6ad
add comment about matplotlib version requirement
amueller Jul 14, 2018
daae2e2
remove redundant file
amueller Aug 20, 2018
82b5459
Merge branch 'master' into matplotlib_tree_plotting
amueller Sep 28, 2018
7d76ca9
add whatsnew entry that the merge lost
amueller Sep 28, 2018
1a9a874
Merge branch 'master' into matplotlib_tree_plotting
amueller Sep 28, 2018
1937f3d
fix merge issue
amueller Oct 1, 2018
db464a3
more merge issues
amueller Oct 1, 2018
042865a
whitespace ...
amueller Oct 1, 2018
1a16750
remove doctest skip to see what's happening
amueller Oct 3, 2018
6f2d597
added some simple invariance tests buchheim function
amueller Oct 3, 2018
6803f96
refactor
amueller Oct 3, 2018
7b62316
added some tests of plot_tree
amueller Oct 3, 2018
817b2bb
put skip back in, fix typo, fix versionadded number
amueller Oct 3, 2018
55c7d36
remove unused parameters special_characters and parallel_leaves from …
amueller Oct 3, 2018
c228ae0
Merge branch 'master' into matplotlib_tree_plotting
amueller Oct 7, 2018
69e7c40
Merge branch 'master' into matplotlib_tree_plotting
amueller Oct 10, 2018
fc756d0
Merge branch 'matplotlib_tree_plotting' of github.com:amueller/scikit…
amueller Oct 10, 2018
9554a87
rename tests to test_reingold_tilford
amueller Oct 10, 2018
becfa07
Merge branch 'master' into matplotlib_tree_plotting
amueller Oct 10, 2018
435d217
added license header from pymag-trees repo
amueller Oct 10, 2018
98e8d5a
remove duplicate test file.
amueller Oct 11, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ endif
# Internal variables.
PAPEROPT_a4 = -D latex_paper_size=a4
PAPEROPT_letter = -D latex_paper_size=letter
ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS)\
ALLSPHINXOPTS = -T -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS)\
$(EXAMPLES_PATTERN_OPTS) .


Expand Down
15 changes: 13 additions & 2 deletions doc/modules/tree.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,20 @@ Using the Iris dataset, we can construct a tree as follows::
>>> clf = tree.DecisionTreeClassifier()
>>> clf = clf.fit(iris.data, iris.target)

Once trained, we can export the tree in `Graphviz
Once trained, you can plot the tree with the plot_tree function::


>>> tree.plot_tree(clf.fit(iris.data, iris.target)) # doctest: +SKIP

.. figure:: ../auto_examples/tree/images/sphx_glr_plot_iris_002.png
:target: ../auto_examples/tree/plot_iris.html
:scale: 75
:align: center

We can also export the tree in `Graphviz
<https://www.graphviz.org/>`_ format using the :func:`export_graphviz`
exporter. If you use the `conda <https://conda.io/>`_ package manager, the graphviz binaries
exporter. If you use the `conda <https://conda.io>`_ package manager, the graphviz binaries

and the python package can be installed with

conda install python-graphviz
Expand Down
5 changes: 5 additions & 0 deletions doc/whats_new/v0.21.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ Support for Python 3.4 and below has been officially dropped.

:mod:`sklearn.tree`
...................
- Decision Trees can now be plotted with matplotlib using
:func:`tree.export.plot_tree` without relying on the ``dot`` library,
removing a hard-to-install dependency.
:issue:`8508` by `Andreas Müller`_.

- |Feature| ``get_n_leaves()`` and ``get_depth()`` have been added to
:class:`tree.BaseDecisionTree` and consequently all estimators based
on it, including :class:`tree.DecisionTreeClassifier`,
Expand Down
8 changes: 7 additions & 1 deletion examples/tree/plot_iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@
For each pair of iris features, the decision tree learns decision
boundaries made of combinations of simple thresholding rules inferred from
the training samples.

We also show the tree structure of a model built on all of the features.
"""
print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Parameters
n_classes = 3
Expand Down Expand Up @@ -62,4 +64,8 @@
plt.suptitle("Decision surface of a decision tree using paired features")
plt.legend(loc='lower right', borderpad=0, handletextpad=0)
plt.axis("tight")

plt.figure()
clf = DecisionTreeClassifier().fit(iris.data, iris.target)
plot_tree(clf, filled=True)
plt.show()
5 changes: 3 additions & 2 deletions sklearn/tree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from .tree import DecisionTreeRegressor
from .tree import ExtraTreeClassifier
from .tree import ExtraTreeRegressor
from .export import export_graphviz
from .export import export_graphviz, plot_tree

__all__ = ["DecisionTreeClassifier", "DecisionTreeRegressor",
"ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz"]
"ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz",
"plot_tree"]
203 changes: 203 additions & 0 deletions sklearn/tree/_reingold_tilford.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# taken from https://github.com/llimllib/pymag-trees/blob/master/buchheim.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this file not include a more extensive license?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added the license.

# with slight modifications

# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
# Version 2, December 2004
#
# Copyright (C) 2004 Sam Hocevar <[email protected]>
#
# Everyone is permitted to copy and distribute verbatim or modified
# copies of this license document, and changing it is allowed as long
# as the name is changed.
#
# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
# TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION

# 0. You just DO WHAT THE FUCK YOU WANT TO.


import numpy as np


class DrawTree(object):
def __init__(self, tree, parent=None, depth=0, number=1):
self.x = -1.
self.y = depth
self.tree = tree
self.children = [DrawTree(c, self, depth + 1, i + 1)
for i, c
in enumerate(tree.children)]
self.parent = parent
self.thread = None
self.mod = 0
self.ancestor = self
self.change = self.shift = 0
self._lmost_sibling = None
# this is the number of the node in its group of siblings 1..n
self.number = number

def left(self):
return self.thread or len(self.children) and self.children[0]

def right(self):
return self.thread or len(self.children) and self.children[-1]

def lbrother(self):
n = None
if self.parent:
for node in self.parent.children:
if node == self:
return n
else:
n = node
return n

def get_lmost_sibling(self):
if not self._lmost_sibling and self.parent and self != \
self.parent.children[0]:
self._lmost_sibling = self.parent.children[0]
return self._lmost_sibling
lmost_sibling = property(get_lmost_sibling)

def __str__(self):
return "%s: x=%s mod=%s" % (self.tree, self.x, self.mod)

def __repr__(self):
return self.__str__()

def max_extents(self):
extents = [c.max_extents() for c in self. children]
extents.append((self.x, self.y))
return np.max(extents, axis=0)


def buchheim(tree):
dt = first_walk(DrawTree(tree))
min = second_walk(dt)
if min < 0:
third_walk(dt, -min)
return dt


def third_walk(tree, n):
tree.x += n
for c in tree.children:
third_walk(c, n)


def first_walk(v, distance=1.):
if len(v.children) == 0:
if v.lmost_sibling:
v.x = v.lbrother().x + distance
else:
v.x = 0.
else:
default_ancestor = v.children[0]
for w in v.children:
first_walk(w)
default_ancestor = apportion(w, default_ancestor, distance)
# print("finished v =", v.tree, "children")
execute_shifts(v)

midpoint = (v.children[0].x + v.children[-1].x) / 2

w = v.lbrother()
if w:
v.x = w.x + distance
v.mod = v.x - midpoint
else:
v.x = midpoint
return v


def apportion(v, default_ancestor, distance):
w = v.lbrother()
if w is not None:
# in buchheim notation:
# i == inner; o == outer; r == right; l == left; r = +; l = -
vir = vor = v
vil = w
vol = v.lmost_sibling
sir = sor = v.mod
sil = vil.mod
sol = vol.mod
while vil.right() and vir.left():
vil = vil.right()
vir = vir.left()
vol = vol.left()
vor = vor.right()
vor.ancestor = v
shift = (vil.x + sil) - (vir.x + sir) + distance
if shift > 0:
move_subtree(ancestor(vil, v, default_ancestor), v, shift)
sir = sir + shift
sor = sor + shift
sil += vil.mod
sir += vir.mod
sol += vol.mod
sor += vor.mod
if vil.right() and not vor.right():
vor.thread = vil.right()
vor.mod += sil - sor
else:
if vir.left() and not vol.left():
vol.thread = vir.left()
vol.mod += sir - sol
default_ancestor = v
return default_ancestor


def move_subtree(wl, wr, shift):
subtrees = wr.number - wl.number
# print(wl.tree, "is conflicted with", wr.tree, 'moving', subtrees,
# 'shift', shift)
# print wl, wr, wr.number, wl.number, shift, subtrees, shift/subtrees
wr.change -= shift / subtrees
wr.shift += shift
wl.change += shift / subtrees
wr.x += shift
wr.mod += shift


def execute_shifts(v):
shift = change = 0
for w in v.children[::-1]:
# print("shift:", w, shift, w.change)
w.x += shift
w.mod += shift
change += w.change
shift += w.shift + change


def ancestor(vil, v, default_ancestor):
# the relevant text is at the bottom of page 7 of
# "Improving Walker's Algorithm to Run in Linear Time" by Buchheim et al,
# (2002)
# http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.16.8757&rep=rep1&type=pdf
if vil.ancestor in v.parent.children:
return vil.ancestor
else:
return default_ancestor


def second_walk(v, m=0, depth=0, min=None):
v.x += m
v.y = depth

if min is None or v.x < min:
min = v.x

for w in v.children:
min = second_walk(w, m + v.mod, depth + 1, min)

return min


class Tree(object):
def __init__(self, label="", node_id=-1, *children):
self.label = label
self.node_id = node_id
if children:
self.children = children
else:
self.children = []
Loading