Skip to content

Commit d21839e

Browse files
committed
[plot] fix plotting hold on, return fig/ax/obj dict, 0.3.3
1 parent 76d8a4c commit d21839e

File tree

5 files changed

+102
-63
lines changed

5 files changed

+102
-63
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
* Copyright: (C) Qianqian Fang (2024-2025) <q.fang at neu.edu>, Edward Xu (2024) <xu.ed at northeastern.edu>
66
* License: GNU Public License V3 or later
7-
* Version: 0.3.2
7+
* Version: 0.3.3
88
* URL: [https://pypi.org/project/iso2mesh/](https://pypi.org/project/iso2mesh/)
99
* Github: [https://github.com/NeuroJSON/pyiso2mesh](https://github.com/NeuroJSON/pyiso2mesh)
1010

iso2mesh/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@
137137
barycentricgrid,
138138
)
139139

140-
__version__ = "0.3.2"
140+
__version__ = "0.3.3"
141141
__all__ = [
142142
"advancefront",
143143
"barycentricgrid",

iso2mesh/plot.py

Lines changed: 61 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def plotsurf(node, face, *args, **kwargs):
2929
from matplotlib.colors import Normalize
3030

3131
rngstate = np.random.get_state()
32-
h = []
3332

3433
randseed = int("623F9A9E", 16) + COLOR_OFFSET
3534

@@ -39,11 +38,11 @@ def plotsurf(node, face, *args, **kwargs):
3938

4039
sc = np.random.rand(10, 3)
4140

42-
ax = plt.gca()
41+
ax = _createaxis(*args, **kwargs)
4342

44-
if ax.name != "3d":
45-
plt.figure() # Create a new figure
46-
ax = plt.gcf().add_subplot(projection="3d") # Add 3D axes to the current figure
43+
h = {"fig": [], "ax": [], "obj": []}
44+
h["fig"].append(plt.gcf())
45+
h["ax"].append(ax)
4746

4847
if not "color" in kwargs and not "cmap" in kwargs:
4948
kwargs["cmap"] = plt.get_cmap("jet")
@@ -127,7 +126,7 @@ def plotsurf(node, face, *args, **kwargs):
127126

128127
ax.add_collection3d(patch)
129128
_autoscale_3d(ax, node)
130-
h.append(ax)
129+
h["obj"].append(patch)
131130

132131
np.random.set_state(rngstate)
133132
# plt.axis("equal")
@@ -181,16 +180,15 @@ def plottetra(node, elem, *args, **kwargs):
181180

182181
np.random.seed(randseed)
183182

184-
ax = plt.gca()
183+
ax = _createaxis(*args, **kwargs)
185184

186-
if ax.name != "3d":
187-
plt.figure() # Create a new figure
188-
ax = plt.gcf().add_subplot(projection="3d") # Add 3D axes to the current figure
185+
h = {"fig": [], "ax": [], "obj": []}
186+
h["fig"].append(plt.gcf())
187+
h["ax"].append(ax)
189188

190189
if not "color" in kwargs and not "cmap" in kwargs:
191190
kwargs["cmap"] = plt.get_cmap("jet")
192191

193-
h = []
194192
polydata = []
195193
colormap = []
196194

@@ -224,16 +222,15 @@ def plottetra(node, elem, *args, **kwargs):
224222

225223
patch = Poly3DCollection(polydata, edgecolors="k", **kwargs)
226224
ax.add_collection3d(patch)
227-
228225
_autoscale_3d(ax, node)
229-
h.append(ax)
226+
227+
h["obj"].append(patch)
230228

231229
# Restore RNG state
232230
np.random.set_state(rngstate)
233231

234232
# Return handle if needed
235-
if h:
236-
return h
233+
return h
237234

238235

239236
# _________________________________________________________________________________________________________
@@ -260,14 +257,19 @@ def plotedges(node, edges, *args, **kwargs):
260257
Handles to plotted elements.
261258
"""
262259
edges = np.asarray(edges, order="F") # Flatten in F order if needed
263-
hh = []
264260

265261
if edges.size == 0:
266262
return hh
267263

268264
edlen = edges.shape[0]
269265
rng_state = np.random.get_state()
270266

267+
ax = _createaxis(*args, **kwargs)
268+
269+
hh = {"fig": [], "ax": [], "obj": []}
270+
hh["fig"].append(plt.gcf())
271+
hh["ax"].append(ax)
272+
271273
if edges.ndim == 1 or edges.shape[1] == 1:
272274
# Loop: NaN-separated index list
273275
randseed = int("623F9A9E", 16) + COLOR_OFFSET
@@ -296,7 +298,7 @@ def plotedges(node, edges, *args, **kwargs):
296298
*args,
297299
**kwargs,
298300
)
299-
hh.append(h)
301+
hh["obj"].append(h)
300302
seghead = i + 1
301303
else:
302304
from mpl_toolkits.mplot3d.art3d import Line3DCollection
@@ -318,13 +320,12 @@ def plotedges(node, edges, *args, **kwargs):
318320
ax.add_collection3d(h)
319321
_autoscale_3d(ax, node)
320322

321-
hh.append(h)
322323
else:
323324
x = node[:, 0].flatten()
324325
y = node[:, 1].flatten()
325326
h = plt.plot(x[edges.T], y[edges.T], *args, **kwargs)
326327

327-
hh.append(h)
328+
hh["obj"].append(h)
328329

329330
np.random.set_state(rng_state)
330331
return hh
@@ -335,7 +336,7 @@ def plotedges(node, edges, *args, **kwargs):
335336

336337
def plotmesh(node, *args, **kwargs):
337338
"""
338-
plotmesh(node, face, elem, opt) → hm
339+
handles = plotmesh(node, face, elem, selector, ...)
339340
Plot surface and volumetric meshes in 3D.
340341
Converts 1-based MATLAB indices in `face` and `elem` to 0-based.
341342
Supports optional selector strings and stylistic options.
@@ -376,18 +377,18 @@ def plotmesh(node, *args, **kwargs):
376377
elem = a
377378

378379
extraarg = {}
379-
if len(opt) > 1 and len(opt) % 2 == 0:
380-
extraarg = dict(zip(opt[::2], opt[1::2]))
380+
if "hold" in kwargs:
381+
extraarg["hold"] = kwargs["hold"]
381382

382-
handles = []
383+
ax = _createaxis(True, *args, **kwargs)
383384

384-
ax = kwargs.get("parent", None)
385+
handles = {"fig": [], "ax": [], "obj": []}
386+
handles["fig"].append(plt.gcf())
387+
handles["ax"].append(ax)
385388

386-
if ax is None:
387-
fig = plt.figure()
388-
ax = fig.add_subplot(111, projection="3d")
389-
else:
390-
del kwargs["parent"]
389+
for extraopt in ["hold", "parent", "subplot"]:
390+
if extraopt in kwargs:
391+
del kwargs[extraopt]
391392

392393
# Plot points if no face/elem
393394
if face is None and elem is None:
@@ -400,17 +401,17 @@ def plotmesh(node, *args, **kwargs):
400401
if getattr(idx, "size", None) == 0:
401402
print("Warning: nothing to plot")
402403
return None
403-
ax.plot(x[idx], y[idx], z[idx], **kwargs)
404+
(h,) = ax.plot(x[idx], y[idx], z[idx], *opt, **kwargs)
405+
handles["obj"].append(h)
404406
_autoscale_3d(ax, node)
405407
if not "hold" in extraarg or not extraarg["hold"] or extraarg["hold"] == "off":
406408
plt.show(block=False)
407-
return ax
409+
return handles
408410

409411
# Plot surface mesh
410412
if face is not None:
411413
if isinstance(face, list):
412-
ax = plotsurf(node, face, opt, *args, **kwargs)
413-
handles.append(ax)
414+
handles = plotsurf(node, face, opt, *args, **kwargs)
414415
else:
415416
c0 = meshcentroid(node[:, :3], face[:, :3])
416417
x, y, z = c0[:, 0], c0[:, 1], c0[:, 2]
@@ -422,8 +423,7 @@ def plotmesh(node, *args, **kwargs):
422423
if getattr(idx, "size", None) == 0:
423424
print("Warning: nothing to plot")
424425
return None
425-
ax = plotsurf(node, face[idx, :], opt, *args, **kwargs)
426-
handles.append(ax)
426+
handles = plotsurf(node, face[idx, :], opt, *args, **kwargs)
427427

428428
# Plot tetrahedral mesh
429429
if elem is not None:
@@ -437,13 +437,12 @@ def plotmesh(node, *args, **kwargs):
437437
if getattr(idx, "size", None) == 0:
438438
print("Warning: nothing to plot")
439439
return None
440-
ax = plottetra(node, elem[idx, :], opt, *args, **kwargs)
441-
handles.append(ax)
440+
handles = plottetra(node, elem[idx, :], opt, *args, **kwargs)
442441

443442
if not "hold" in extraarg or not extraarg["hold"] or extraarg["hold"] == "off":
444443
plt.show(block=False)
445444

446-
return handles if len(handles) > 1 else handles[0]
445+
return handles
447446

448447

449448
def _autoscale_3d(ax, points):
@@ -453,3 +452,27 @@ def _autoscale_3d(ax, points):
453452
ax.set_zlim([z.min(), z.max()])
454453
boxas = [x.max() - x.min(), y.max() - y.min(), z.max() - z.min()]
455454
ax.set_box_aspect(boxas)
455+
456+
457+
def _createaxis(*args, **kwargs):
458+
subplotid = kwargs.get("subplot", 111)
459+
docreate = False if len(args) == 0 else args[0]
460+
461+
if "parent" in kwargs:
462+
ax = kwargs["parent"]
463+
if isinstance(ax, dict):
464+
ax = ax["ax"][-1]
465+
elif isinstance(ax, list):
466+
ax = ax[-1]
467+
elif not docreate and len(plt.get_fignums()) > 0 and len(plt.gcf().axes) > 0:
468+
ax = plt.gcf().axes[-1]
469+
else:
470+
if docreate:
471+
plt.figure()
472+
ax = plt.gcf().add_subplot(subplotid, projection="3d")
473+
474+
if ax.name != "3d":
475+
plt.figure() # Create a new figure
476+
ax = plt.gcf().add_subplot(subplotid, projection="3d")
477+
478+
return ax

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
setup(
77
name="iso2mesh",
88
packages=["iso2mesh"],
9-
version="0.3.2",
9+
version="0.3.3",
1010
license='GPLv3+',
1111
description="Image-based 3D Surface and Volumetric Mesh Generator",
1212
long_description=readme,

test/run_test.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,78 +1106,94 @@ def __init__(self, *args, **kwargs):
11061106
)
11071107
self.fc = volface(self.el)[0]
11081108

1109+
def test_plotmesh_node(self):
1110+
ax = plotmesh(self.no, "ro", hold="on")
1111+
xy = ax["obj"][-1]._xy
1112+
expected_fc = [
1113+
[1.0, -1.0],
1114+
[2.0, -1.0],
1115+
[1.0, 0.0],
1116+
[2.0, 0.0],
1117+
[1.0, -1.0],
1118+
[2.0, -1.0],
1119+
[1.0, 0.0],
1120+
[2.0, 0.0],
1121+
[1.0, -1.0],
1122+
[2.0, -1.0],
1123+
[1.0, 0.0],
1124+
[2.0, 0.0],
1125+
]
1126+
1127+
self.assertEqual(xy.tolist(), expected_fc)
1128+
11091129
def test_plotmesh_face(self):
1110-
patch = plotmesh(self.no, self.fc, "hold", True)
1111-
facecolors = np.array(patch[0].get_facecolors())
1130+
ax = plotmesh(self.no, self.fc, hold="on")
1131+
facecolors = np.array(ax["obj"][-1].get_facecolors())
11121132
expected_fc = [9.0, 8.6803, 9.0, 20.0]
11131133

11141134
self.assertEqual(len(facecolors), 20)
11151135
self.assertEqual(np.round(np.sum(facecolors, axis=0), 4).tolist(), expected_fc)
11161136

11171137
def test_plotmesh_elem(self):
1118-
patch = plotmesh(self.no, self.el, "hold", True)
1119-
facecolors = np.array(patch[0].get_facecolors())
1138+
ax = plotmesh(self.no, self.el, hold="on")
1139+
facecolors = np.array(ax["obj"][-1].get_facecolors())
11201140
expected_fc = [9.0, 8.6803, 9.0, 20.0]
11211141

11221142
self.assertEqual(len(facecolors), 20)
11231143
self.assertEqual(np.round(np.sum(facecolors, axis=0), 4).tolist(), expected_fc)
11241144

11251145
def test_plotmesh_elemlabel(self):
1126-
patch = plotmesh(
1146+
ax = plotmesh(
11271147
self.no,
11281148
np.hstack((self.el, np.ones(self.el.shape[0], dtype=int).reshape(-1, 1))),
1129-
"hold",
1130-
True,
1149+
hold="on",
11311150
)
1132-
facecolors = np.array(patch[0].get_facecolors())
1151+
facecolors = np.array(ax["obj"][-1].get_facecolors())
11331152
expected_fc = (1, 4)
11341153

11351154
self.assertEqual(len(facecolors), 20)
11361155
self.assertEqual(np.unique(facecolors, axis=0).shape, expected_fc)
11371156

11381157
def test_plotmesh_facelabel(self):
1139-
patch = plotmesh(
1158+
ax = plotmesh(
11401159
self.no,
11411160
np.hstack((self.fc, np.array([1, 2] * 10).reshape(-1, 1))),
11421161
None,
1143-
"hold",
1144-
True,
1162+
hold="on",
11451163
)
1146-
facecolors = np.array(patch[0].get_facecolors())
1164+
facecolors = np.array(ax["obj"][-1].get_facecolors())
11471165
expected_fc = (2, 4)
11481166

11491167
self.assertEqual(len(facecolors), 20)
11501168
self.assertEqual(np.unique(facecolors, axis=0).shape, expected_fc)
11511169

11521170
def test_plotmesh_elemnodeval(self):
1153-
patch = plotmesh(self.no[:, [0, 1, 2, 0]], self.el, "hold", True)
1154-
facecolors = np.array(patch[0].get_facecolors())
1171+
ax = plotmesh(self.no[:, [0, 1, 2, 0]], self.el, hold="on")
1172+
facecolors = np.array(ax["obj"][-1].get_facecolors())
11551173
expected_fc = [8.0, 10.4074, 8.0, 20.0]
11561174

11571175
self.assertEqual(len(facecolors), 20)
11581176
self.assertEqual(np.round(np.sum(facecolors, axis=0), 4).tolist(), expected_fc)
11591177

11601178
def test_plotmesh_facenodeval(self):
1161-
patch = plotmesh(self.no[:, [0, 1, 2, 0]], self.fc, "z < 3", "hold", True)
1162-
facecolors = np.array(patch[0].get_facecolors())
1179+
ax = plotmesh(self.no[:, [0, 1, 2, 0]], self.fc, "z < 3", hold="on")
1180+
facecolors = np.array(ax["obj"][-1].get_facecolors())
11631181
expected_fc = [7.0, 8.6728, 7.0, 18.0]
11641182

11651183
self.assertEqual(len(facecolors), 18)
11661184
self.assertEqual(np.round(np.sum(facecolors, axis=0), 4).tolist(), expected_fc)
11671185

11681186
def test_plotmesh_selector(self):
1169-
patch = plotmesh(
1170-
self.no[:, [0, 1, 2, 0]], self.fc, "(z < 3) & (x < 2)", "hold", True
1171-
)
1172-
facecolors = np.array(patch[0].get_facecolors())
1187+
ax = plotmesh(self.no[:, [0, 1, 2, 0]], self.fc, "(z < 3) & (x < 2)", hold="on")
1188+
facecolors = np.array(ax["obj"][-1].get_facecolors())
11731189
expected_fc = [4.8877, 5.0, 4.451, 14.0]
11741190

11751191
self.assertEqual(len(facecolors), 14)
11761192
self.assertEqual(np.round(np.sum(facecolors, axis=0), 4).tolist(), expected_fc)
11771193

11781194
def test_plotmesh_elemselector(self):
1179-
patch = plotmesh(self.no, self.fc, "z < 2.5", "hold", True)
1180-
facecolors = np.array(patch[0].get_facecolors())
1195+
ax = plotmesh(self.no, self.fc, "z < 2.5", hold="on")
1196+
facecolors = np.array(ax["obj"][-1].get_facecolors())
11811197
expected_fc = [3.9102, 4.0, 2.9608, 10.0]
11821198

11831199
self.assertEqual(len(facecolors), 10)

0 commit comments

Comments
 (0)