Skip to content

Commit d81833a

Browse files
authored
Merge pull request #30 from benabel/master
Implement svg output in notebooks #29
2 parents a2fd4ec + c4a2577 commit d81833a

File tree

4 files changed

+159
-14
lines changed

4 files changed

+159
-14
lines changed

binarytree/__init__.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from dataclasses import dataclass
66
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
77

8-
from graphviz import Digraph, nohtml
98
from pkg_resources import get_distribution
109

1110
from binarytree.exceptions import (
11+
GraphvizImportError,
1212
NodeIndexError,
1313
NodeModifyError,
1414
NodeNotFoundError,
@@ -18,6 +18,15 @@
1818
TreeHeightError,
1919
)
2020

21+
try:
22+
from graphviz import Digraph, nohtml
23+
24+
GRAPHVIZ_INSTALLED = True
25+
except ImportError:
26+
GRAPHVIZ_INSTALLED = False
27+
Digraph = Any
28+
from binarytree.layout import generate_svg
29+
2130
__version__ = get_distribution("binarytree").version
2231

2332
LEFT_FIELD = "left"
@@ -463,29 +472,32 @@ def _repr_svg_(self) -> str:
463472
464473
.. _Jupyter notebooks: https://jupyter.org
465474
"""
466-
# noinspection PyProtectedMember
467-
return str(self.graphviz()._repr_svg_())
475+
if GRAPHVIZ_INSTALLED:
476+
# noinspection PyProtectedMember
477+
return str(self.graphviz()._repr_svg_())
478+
else:
479+
return generate_svg(self.values) # pragma: no cover
468480

469481
def graphviz(self, *args: Any, **kwargs: Any) -> Digraph:
470482
"""Return a graphviz.Digraph_ object representing the binary tree.
471-
472483
This method's positional and keyword arguments are passed directly into the
473484
the Digraph's **__init__** method.
474-
475485
:return: graphviz.Digraph_ object representing the binary tree.
476-
486+
:raise binarytree.exceptions.GraphvizImportError: If graphviz is not installed
477487
.. code-block:: python
478-
479488
>>> from binarytree import tree
480489
>>>
481490
>>> t = tree()
482491
>>>
483492
>>> graph = t.graphviz() # Generate a graphviz object
484493
>>> graph.body # Get the DOT body
485494
>>> graph.render() # Render the graph
486-
487495
.. _graphviz.Digraph: https://graphviz.readthedocs.io/en/stable/api.html#digraph
488496
"""
497+
if not GRAPHVIZ_INSTALLED:
498+
raise GraphvizImportError(
499+
"Can't use graphviz method if graphviz module is not installed"
500+
)
489501
if "node_attr" not in kwargs:
490502
kwargs["node_attr"] = {
491503
"shape": "record",
@@ -494,20 +506,14 @@ def graphviz(self, *args: Any, **kwargs: Any) -> Digraph:
494506
"fillcolor": "lightgray",
495507
"fontcolor": "black",
496508
}
497-
498509
digraph = Digraph(*args, **kwargs)
499-
500510
for node in self:
501511
node_id = str(id(node))
502-
503512
digraph.node(node_id, nohtml(f"<l>|<v> {node.value}|<r>"))
504-
505513
if node.left is not None:
506514
digraph.edge(f"{node_id}:l", f"{id(node.left)}:v")
507-
508515
if node.right is not None:
509516
digraph.edge(f"{node_id}:r", f"{id(node.right)}:v")
510-
511517
return digraph
512518

513519
def pprint(self, index: bool = False, delimiter: str = "-") -> None:

binarytree/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,7 @@ class NodeValueError(BinaryTreeError):
2828

2929
class TreeHeightError(BinaryTreeError):
3030
"""Tree height was invalid."""
31+
32+
33+
class GraphvizImportError(BinaryTreeError):
34+
"""graphviz module is not installed"""

binarytree/layout.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
""" Module containing layout related algorithms."""
2+
from typing import List, Tuple, Union
3+
4+
5+
def _get_coords(
6+
values: List[Union[float, int, None]]
7+
) -> Tuple[
8+
List[Tuple[int, int, Union[float, int, None]]], List[Tuple[int, int, int, int]]
9+
]:
10+
"""Generate the coordinates used for rendering the nodes and edges.
11+
12+
node and edges are stored as tuples in the form node: (x, y, label) and
13+
edge: (x1, y1, x2, y2)
14+
15+
Each coordinate is relative y is the depth, x is the position of the node
16+
on a level from left to right 0 to 2**depth -1
17+
18+
:param values: Values of the binary tree.
19+
:type values: list of ints
20+
:return: nodes and edges list
21+
:rtype: two lists of tuples
22+
23+
"""
24+
x = 0
25+
y = 0
26+
nodes = []
27+
edges = []
28+
29+
# root node
30+
nodes.append((x, y, values[0]))
31+
# append other nodes and their edges
32+
y += 1
33+
for value in values[1:]:
34+
if value is not None:
35+
nodes.append((x, y, value))
36+
edges.append((x // 2, y - 1, x, y))
37+
x += 1
38+
# check if level is full
39+
if x == 2 ** y:
40+
x = 0
41+
y += 1
42+
return nodes, edges
43+
44+
45+
def generate_svg(values: List[Union[float, int, None]]) -> str:
46+
"""Generate a svg image from a binary tree
47+
48+
A simple layout is used based on a perfect tree of same height in which all
49+
leaves would be regularly spaced.
50+
51+
:param values: Values of the binary tree.
52+
:type values: list of ints
53+
:return: the svg image of the tree.
54+
:rtype: str
55+
"""
56+
node_size = 16.0
57+
stroke_width = 1.5
58+
gutter = 0.5
59+
x_scale = (2 + gutter) * node_size
60+
y_scale = 3.0 * node_size
61+
62+
# retrieve relative coordinates
63+
nodes, edges = _get_coords(values)
64+
y_min = min([n[1] for n in nodes])
65+
y_max = max([n[1] for n in nodes])
66+
67+
# generate the svg string
68+
svg = f"""
69+
<svg width="{x_scale * 2**y_max}" height="{y_scale * (2 + y_max)}"
70+
xmlns="http://www.w3.org/2000/svg">
71+
<style>
72+
.bt-label {{
73+
font: 300 {node_size}px sans-serif;;
74+
text-align: center;
75+
dominant-baseline: middle;
76+
text-anchor: middle;
77+
}}
78+
.bt-node {{
79+
fill: lightgray;
80+
stroke-width: {stroke_width};
81+
}}
82+
83+
</style>
84+
<g stroke="#111">
85+
"""
86+
# scales
87+
88+
def scalex(x: int, y: int) -> float:
89+
depth = y_max - y
90+
# offset
91+
x = 2 ** (depth + 1) * x + 2 ** depth - 1
92+
return 1 + node_size + x_scale * x / 2
93+
94+
def scaley(y: int) -> float:
95+
return float(y_scale * (1 + y - y_min))
96+
97+
# edges
98+
def svg_edge(x1: float, y1: float, x2: float, y2: float) -> str:
99+
"""Generate svg code for an edge"""
100+
return f"""<line x1="{x1}" x2="{x2}" y1="{y1}" y2="{y2}"/>"""
101+
102+
for a in edges:
103+
x1, y1, x2, y2 = a
104+
svg += svg_edge(scalex(x1, y1), scaley(y1), scalex(x2, y2), scaley(y2))
105+
106+
# nodes
107+
def svg_node(x: float, y: float, label: str = "") -> str:
108+
"""Generate svg code for a node and his label"""
109+
return f"""
110+
<circle class="bt-node" cx="{x}" cy="{y}" r="{node_size}"/>
111+
<text class="bt-label" x="{x}" y="{y}">{label}</text>"""
112+
113+
for n in nodes:
114+
x, y, label = n
115+
svg += svg_node(scalex(x, y), scaley(y), str(label))
116+
117+
svg += "</g></svg>"
118+
return svg

tests/test_layout.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import xml.etree.ElementTree as ET
2+
3+
from binarytree.layout import _get_coords, generate_svg
4+
5+
6+
def test_get_coords():
7+
values = [0, 6, 5, None, 1, 4, 2]
8+
assert _get_coords(values) == (
9+
[(0, 0, 0), (0, 1, 6), (1, 1, 5), (1, 2, 1), (2, 2, 4), (3, 2, 2)],
10+
[(0, 0, 0, 1), (0, 0, 1, 1), (0, 1, 1, 2), (1, 1, 2, 2), (1, 1, 3, 2)],
11+
)
12+
13+
14+
def test_svg():
15+
svg = generate_svg([0, 1, 2])
16+
svg_tree = ET.fromstring(svg)
17+
assert svg_tree.tag == "{http://www.w3.org/2000/svg}svg"

0 commit comments

Comments
 (0)