Skip to content

Commit fb1ab69

Browse files
committed
Add curve visualization classes for matplotlib
1 parent 54c9046 commit fb1ab69

File tree

1 file changed

+49
-1
lines changed

1 file changed

+49
-1
lines changed

geomdl/geomdl_vis/VisMPL.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,55 @@
1515
import matplotlib.pyplot as plt
1616

1717

18+
class VisCurve2D(VisAbstract):
19+
20+
def __init__(self):
21+
super(VisCurve2D, self).__init__()
22+
23+
def render(self):
24+
if not self._points:
25+
return False
26+
27+
cpts = np.array(self._points[0])
28+
crvpts = np.array(self._points[1])
29+
30+
# Draw control points polygon and the curve
31+
plt.figure(figsize=(10.67, 8), dpi=96)
32+
cppolygon, = plt.plot(cpts[:, 0], cpts[:, 1], self._colors[0])
33+
curveplt, = plt.plot(crvpts[:, 0], crvpts[:, 1], self._colors[1])
34+
plt.legend([cppolygon, curveplt], [self._names[0], self._names[1]])
35+
plt.show()
36+
37+
38+
class VisCurve3D(VisAbstract):
39+
40+
def __init__(self):
41+
super(VisCurve3D, self).__init__()
42+
43+
def render(self):
44+
if not self._points:
45+
return False
46+
47+
cpts = np.array(self._points[0])
48+
crvpts = np.array(self._points[1])
49+
50+
# Draw control points polygon and the 3D curve
51+
fig = plt.figure(figsize=(10.67, 8), dpi=96)
52+
ax = fig.gca(projection='3d')
53+
54+
# Plot 3D lines
55+
ax.plot(cpts[:, 0], cpts[:, 1], cpts[:, 2], self._colors[0])
56+
ax.plot(crvpts[:, 0], crvpts[:, 1], crvpts[:, 2], self._colors[1])
57+
58+
# Add legend to 3D plot, @ref: https://stackoverflow.com/a/20505720
59+
scatter1_proxy = matplotlib.lines.Line2D([0], [0], linestyle='none', color=self._colors[0], marker='o')
60+
scatter2_proxy = matplotlib.lines.Line2D([0], [0], linestyle='none', color=self._colors[1], marker='o')
61+
ax.legend([scatter1_proxy, scatter2_proxy], [self._names[0], self._names[1]], numpoints=1)
62+
63+
# Display the 3D plot
64+
plt.show()
65+
66+
1867
class VisSurfWireframe(VisAbstract):
1968

2069
def __init__(self):
@@ -24,7 +73,6 @@ def render(self):
2473
if not self._points:
2574
return False
2675

27-
# Read surface and control points, @ref: https://stackoverflow.com/a/13550615
2876
cpgrid = np.array(self._points[0])
2977
surf = np.array(self._points[1])
3078

0 commit comments

Comments
 (0)