Skip to content

Commit c13ba26

Browse files
authored
[MRG] Matplotlib tree plotting (scikit-learn#9251)
* add reingold tillford tree layout algorithm * add first silly implementation of matplotlib based plotting for trees * object oriented design for export_graphviz so it can be extended * add class for mlp export * add colors * separately scale x and y, add arrowheads, fix strings * implement max_depth * don't use alpha for coloring because it makes boxes transparent * remove unused variables * vertical center of boxes * fix/simplify newline trimming * somewhere in the middle of stuff trying to get rid of scalex, scaley * remove "find_longest_child" for now, fix tests * make scalex and scaley internal, and ax local. render everything once to get the bbox sizes, then again to actually plot it with known extents. * add some margin to the max bbox width * add _BaseTreeExporter baseclass * add docstring to plot_tree * use data coordinates so we can put the plot in a subplot, remove some hacks. * remove scalex, scaley, add automatic font size * use rendered stuff for setting limits (well nearly there) * import plot_tree into tree module * set limits before font size adjustment? * add tree plotting via matplotlib to iris example and to docs * pep8 fix * skip doctest on plot_tree because matplotlib is not installed on all CI machines * redo everything in axis pixel coordinates re-introduce scalex, scaley add max_extents to tree to get tree size before plotting * fix max-depth parent node positioning and don't consider deep nodes in layouting * consider height in fontsize computation in case someone gave us a very flat figure * fix error when max_depth is None * add docstring for tree plotting fontsize * starting on jnothman's review * renaming fixes * whatsnew for tree plotting * clear axes prior to doing anything. * fix doctests * skip matplotlib doctest * trying to debug circle failure * trying to show full traceback * more print debugging * remove debugging crud * hack around matplotlib <1.5 issues * copy bbox args because old matplotlib is weird. * pep8 fixes * add explicit boxstyle * more pep8 * even more pep8 * add comment about matplotlib version requirement * remove redundant file * add whatsnew entry that the merge lost * fix merge issue * more merge issues * whitespace ... * remove doctest skip to see what's happening * added some simple invariance tests buchheim function * refactor ___init__ into superclass * added some tests of plot_tree * put skip back in, fix typo, fix versionadded number * remove unused parameters special_characters and parallel_leaves from mpl plotting * rename tests to test_reingold_tilford * added license header from pymag-trees repo * remove duplicate test file.
1 parent 0f94f29 commit c13ba26

File tree

9 files changed

+821
-210
lines changed

9 files changed

+821
-210
lines changed

doc/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ endif
1313
# Internal variables.
1414
PAPEROPT_a4 = -D latex_paper_size=a4
1515
PAPEROPT_letter = -D latex_paper_size=letter
16-
ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS)\
16+
ALLSPHINXOPTS = -T -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS)\
1717
$(EXAMPLES_PATTERN_OPTS) .
1818

1919

doc/modules/tree.rst

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,20 @@ Using the Iris dataset, we can construct a tree as follows::
124124
>>> clf = tree.DecisionTreeClassifier()
125125
>>> clf = clf.fit(iris.data, iris.target)
126126

127-
Once trained, we can export the tree in `Graphviz
127+
Once trained, you can plot the tree with the plot_tree function::
128+
129+
130+
>>> tree.plot_tree(clf.fit(iris.data, iris.target)) # doctest: +SKIP
131+
132+
.. figure:: ../auto_examples/tree/images/sphx_glr_plot_iris_002.png
133+
:target: ../auto_examples/tree/plot_iris.html
134+
:scale: 75
135+
:align: center
136+
137+
We can also export the tree in `Graphviz
128138
<https://www.graphviz.org/>`_ format using the :func:`export_graphviz`
129-
exporter. If you use the `conda <https://conda.io/>`_ package manager, the graphviz binaries
139+
exporter. If you use the `conda <https://conda.io>`_ package manager, the graphviz binaries
140+
130141
and the python package can be installed with
131142

132143
conda install python-graphviz

doc/whats_new/v0.21.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ Support for Python 3.4 and below has been officially dropped.
6262

6363
:mod:`sklearn.tree`
6464
...................
65+
- Decision Trees can now be plotted with matplotlib using
66+
:func:`tree.export.plot_tree` without relying on the ``dot`` library,
67+
removing a hard-to-install dependency.
68+
:issue:`8508` by `Andreas Müller`_.
69+
6570
- |Feature| ``get_n_leaves()`` and ``get_depth()`` have been added to
6671
:class:`tree.BaseDecisionTree` and consequently all estimators based
6772
on it, including :class:`tree.DecisionTreeClassifier`,

examples/tree/plot_iris.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@
1111
For each pair of iris features, the decision tree learns decision
1212
boundaries made of combinations of simple thresholding rules inferred from
1313
the training samples.
14+
15+
We also show the tree structure of a model built on all of the features.
1416
"""
1517
print(__doc__)
1618

1719
import numpy as np
1820
import matplotlib.pyplot as plt
1921

2022
from sklearn.datasets import load_iris
21-
from sklearn.tree import DecisionTreeClassifier
23+
from sklearn.tree import DecisionTreeClassifier, plot_tree
2224

2325
# Parameters
2426
n_classes = 3
@@ -62,4 +64,8 @@
6264
plt.suptitle("Decision surface of a decision tree using paired features")
6365
plt.legend(loc='lower right', borderpad=0, handletextpad=0)
6466
plt.axis("tight")
67+
68+
plt.figure()
69+
clf = DecisionTreeClassifier().fit(iris.data, iris.target)
70+
plot_tree(clf, filled=True)
6571
plt.show()

sklearn/tree/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from .tree import DecisionTreeRegressor
88
from .tree import ExtraTreeClassifier
99
from .tree import ExtraTreeRegressor
10-
from .export import export_graphviz
10+
from .export import export_graphviz, plot_tree
1111

1212
__all__ = ["DecisionTreeClassifier", "DecisionTreeRegressor",
13-
"ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz"]
13+
"ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz",
14+
"plot_tree"]

sklearn/tree/_reingold_tilford.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# taken from https://github.com/llimllib/pymag-trees/blob/master/buchheim.py
2+
# with slight modifications
3+
4+
# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
5+
# Version 2, December 2004
6+
#
7+
# Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
8+
#
9+
# Everyone is permitted to copy and distribute verbatim or modified
10+
# copies of this license document, and changing it is allowed as long
11+
# as the name is changed.
12+
#
13+
# DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
14+
# TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
15+
16+
# 0. You just DO WHAT THE FUCK YOU WANT TO.
17+
18+
19+
import numpy as np
20+
21+
22+
class DrawTree(object):
23+
def __init__(self, tree, parent=None, depth=0, number=1):
24+
self.x = -1.
25+
self.y = depth
26+
self.tree = tree
27+
self.children = [DrawTree(c, self, depth + 1, i + 1)
28+
for i, c
29+
in enumerate(tree.children)]
30+
self.parent = parent
31+
self.thread = None
32+
self.mod = 0
33+
self.ancestor = self
34+
self.change = self.shift = 0
35+
self._lmost_sibling = None
36+
# this is the number of the node in its group of siblings 1..n
37+
self.number = number
38+
39+
def left(self):
40+
return self.thread or len(self.children) and self.children[0]
41+
42+
def right(self):
43+
return self.thread or len(self.children) and self.children[-1]
44+
45+
def lbrother(self):
46+
n = None
47+
if self.parent:
48+
for node in self.parent.children:
49+
if node == self:
50+
return n
51+
else:
52+
n = node
53+
return n
54+
55+
def get_lmost_sibling(self):
56+
if not self._lmost_sibling and self.parent and self != \
57+
self.parent.children[0]:
58+
self._lmost_sibling = self.parent.children[0]
59+
return self._lmost_sibling
60+
lmost_sibling = property(get_lmost_sibling)
61+
62+
def __str__(self):
63+
return "%s: x=%s mod=%s" % (self.tree, self.x, self.mod)
64+
65+
def __repr__(self):
66+
return self.__str__()
67+
68+
def max_extents(self):
69+
extents = [c.max_extents() for c in self. children]
70+
extents.append((self.x, self.y))
71+
return np.max(extents, axis=0)
72+
73+
74+
def buchheim(tree):
75+
dt = first_walk(DrawTree(tree))
76+
min = second_walk(dt)
77+
if min < 0:
78+
third_walk(dt, -min)
79+
return dt
80+
81+
82+
def third_walk(tree, n):
83+
tree.x += n
84+
for c in tree.children:
85+
third_walk(c, n)
86+
87+
88+
def first_walk(v, distance=1.):
89+
if len(v.children) == 0:
90+
if v.lmost_sibling:
91+
v.x = v.lbrother().x + distance
92+
else:
93+
v.x = 0.
94+
else:
95+
default_ancestor = v.children[0]
96+
for w in v.children:
97+
first_walk(w)
98+
default_ancestor = apportion(w, default_ancestor, distance)
99+
# print("finished v =", v.tree, "children")
100+
execute_shifts(v)
101+
102+
midpoint = (v.children[0].x + v.children[-1].x) / 2
103+
104+
w = v.lbrother()
105+
if w:
106+
v.x = w.x + distance
107+
v.mod = v.x - midpoint
108+
else:
109+
v.x = midpoint
110+
return v
111+
112+
113+
def apportion(v, default_ancestor, distance):
114+
w = v.lbrother()
115+
if w is not None:
116+
# in buchheim notation:
117+
# i == inner; o == outer; r == right; l == left; r = +; l = -
118+
vir = vor = v
119+
vil = w
120+
vol = v.lmost_sibling
121+
sir = sor = v.mod
122+
sil = vil.mod
123+
sol = vol.mod
124+
while vil.right() and vir.left():
125+
vil = vil.right()
126+
vir = vir.left()
127+
vol = vol.left()
128+
vor = vor.right()
129+
vor.ancestor = v
130+
shift = (vil.x + sil) - (vir.x + sir) + distance
131+
if shift > 0:
132+
move_subtree(ancestor(vil, v, default_ancestor), v, shift)
133+
sir = sir + shift
134+
sor = sor + shift
135+
sil += vil.mod
136+
sir += vir.mod
137+
sol += vol.mod
138+
sor += vor.mod
139+
if vil.right() and not vor.right():
140+
vor.thread = vil.right()
141+
vor.mod += sil - sor
142+
else:
143+
if vir.left() and not vol.left():
144+
vol.thread = vir.left()
145+
vol.mod += sir - sol
146+
default_ancestor = v
147+
return default_ancestor
148+
149+
150+
def move_subtree(wl, wr, shift):
151+
subtrees = wr.number - wl.number
152+
# print(wl.tree, "is conflicted with", wr.tree, 'moving', subtrees,
153+
# 'shift', shift)
154+
# print wl, wr, wr.number, wl.number, shift, subtrees, shift/subtrees
155+
wr.change -= shift / subtrees
156+
wr.shift += shift
157+
wl.change += shift / subtrees
158+
wr.x += shift
159+
wr.mod += shift
160+
161+
162+
def execute_shifts(v):
163+
shift = change = 0
164+
for w in v.children[::-1]:
165+
# print("shift:", w, shift, w.change)
166+
w.x += shift
167+
w.mod += shift
168+
change += w.change
169+
shift += w.shift + change
170+
171+
172+
def ancestor(vil, v, default_ancestor):
173+
# the relevant text is at the bottom of page 7 of
174+
# "Improving Walker's Algorithm to Run in Linear Time" by Buchheim et al,
175+
# (2002)
176+
# http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.16.8757&rep=rep1&type=pdf
177+
if vil.ancestor in v.parent.children:
178+
return vil.ancestor
179+
else:
180+
return default_ancestor
181+
182+
183+
def second_walk(v, m=0, depth=0, min=None):
184+
v.x += m
185+
v.y = depth
186+
187+
if min is None or v.x < min:
188+
min = v.x
189+
190+
for w in v.children:
191+
min = second_walk(w, m + v.mod, depth + 1, min)
192+
193+
return min
194+
195+
196+
class Tree(object):
197+
def __init__(self, label="", node_id=-1, *children):
198+
self.label = label
199+
self.node_id = node_id
200+
if children:
201+
self.children = children
202+
else:
203+
self.children = []

0 commit comments

Comments
 (0)