Skip to content

Commit efd37cb

Browse files
authored
Merge pull request SpikeInterface#334 from jakeswann1/main
Allow plot_probe not to plot on axes, but just return polycollections
2 parents c96510c + 4604a7e commit efd37cb

File tree

1 file changed

+89
-38
lines changed

1 file changed

+89
-38
lines changed

src/probeinterface/plotting.py

Lines changed: 89 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,80 @@
1212
from .utils import get_auto_lims
1313

1414

15+
def create_probe_polygons(
16+
probe,
17+
contacts_colors: list | None = None,
18+
contacts_values: np.ndarray | None = None,
19+
cmap: str = "viridis",
20+
contacts_kargs: dict = {},
21+
probe_shape_kwargs: dict = {},
22+
):
23+
"""Create PolyCollection objects for a Probe.
24+
25+
Parameters
26+
----------
27+
probe : Probe
28+
The probe object
29+
contacts_colors : matplotlib color | None, default: None
30+
The color of the contacts
31+
contacts_values : np.ndarray | None, default: None
32+
Values to color the contacts with
33+
cmap : str, default: "viridis"
34+
A colormap color
35+
contacts_kargs : dict, default: {}
36+
Dict with kwargs for contacts (e.g. alpha, edgecolor, lw)
37+
probe_shape_kwargs : dict, default: {}
38+
Dict with kwargs for probe shape (e.g. alpha, edgecolor, lw)
39+
40+
Returns
41+
-------
42+
poly : PolyCollection
43+
The polygon collection for contacts
44+
poly_contour : PolyCollection | None
45+
The polygon collection for the probe shape
46+
"""
47+
if probe.ndim == 2:
48+
from matplotlib.collections import PolyCollection
49+
50+
Collection = PolyCollection
51+
elif probe.ndim == 3:
52+
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
53+
54+
Collection = Poly3DCollection
55+
else:
56+
raise ValueError(f"Unexpected probe.ndim: {probe.ndim}")
57+
58+
_probe_shape_kwargs = dict(facecolor="green", edgecolor="k", lw=0.5, alpha=0.3)
59+
_probe_shape_kwargs.update(probe_shape_kwargs)
60+
61+
_contacts_kargs = dict(alpha=0.7, edgecolor=[0.3, 0.3, 0.3], lw=0.5)
62+
_contacts_kargs.update(contacts_kargs)
63+
64+
n = probe.get_contact_count()
65+
66+
if contacts_colors is None and contacts_values is None:
67+
contacts_colors = ["orange"] * n
68+
elif contacts_colors is not None:
69+
contacts_colors = contacts_colors
70+
elif contacts_values is not None:
71+
contacts_colors = None
72+
73+
vertices = probe.get_contact_vertices()
74+
poly = Collection(vertices, color=contacts_colors, **_contacts_kargs)
75+
76+
if contacts_values is not None:
77+
poly.set_array(contacts_values)
78+
poly.set_cmap(cmap)
79+
80+
# probe shape
81+
poly_contour = None
82+
planar_contour = probe.probe_planar_contour
83+
if planar_contour is not None:
84+
poly_contour = Collection([planar_contour], **_probe_shape_kwargs)
85+
86+
return poly, poly_contour
87+
88+
1589
def plot_probe(
1690
probe,
1791
ax=None,
@@ -74,11 +148,6 @@ def plot_probe(
74148
"""
75149
import matplotlib.pyplot as plt
76150

77-
if probe.ndim == 2:
78-
from matplotlib.collections import PolyCollection
79-
elif probe.ndim == 3:
80-
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
81-
82151
if ax is None:
83152
if probe.ndim == 2:
84153
fig, ax = plt.subplots()
@@ -89,32 +158,25 @@ def plot_probe(
89158
else:
90159
fig = ax.get_figure()
91160

92-
_probe_shape_kwargs = dict(facecolor="green", edgecolor="k", lw=0.5, alpha=0.3)
93-
_probe_shape_kwargs.update(probe_shape_kwargs)
94-
95-
_contacts_kargs = dict(alpha=0.7, edgecolor=[0.3, 0.3, 0.3], lw=0.5)
96-
_contacts_kargs.update(contacts_kargs)
97-
98-
n = probe.get_contact_count()
99-
100-
if contacts_colors is None and contacts_values is None:
101-
contacts_colors = ["orange"] * n
102-
elif contacts_colors is not None:
103-
contacts_colors = contacts_colors
104-
elif contacts_values is not None:
105-
contacts_colors = None
161+
# Create collections (contacts, probe shape)
162+
poly, poly_contour = create_probe_polygons(
163+
probe,
164+
contacts_colors=contacts_colors,
165+
contacts_values=contacts_values,
166+
cmap=cmap,
167+
contacts_kargs=contacts_kargs,
168+
probe_shape_kwargs=probe_shape_kwargs,
169+
)
106170

107-
vertices = probe.get_contact_vertices()
171+
# Add collections to the axis
108172
if probe.ndim == 2:
109-
poly = PolyCollection(vertices, color=contacts_colors, **_contacts_kargs)
110173
ax.add_collection(poly)
174+
if poly_contour is not None:
175+
ax.add_collection(poly_contour)
111176
elif probe.ndim == 3:
112-
poly = Poly3DCollection(vertices, color=contacts_colors, **_contacts_kargs)
113177
ax.add_collection3d(poly)
114-
115-
if contacts_values is not None:
116-
poly.set_array(contacts_values)
117-
poly.set_cmap(cmap)
178+
if poly_contour is not None:
179+
ax.add_collection3d(poly_contour)
118180

119181
if show_channel_on_click:
120182
assert probe.ndim == 2, "show_channel_on_click works only for ndim=2"
@@ -125,22 +187,11 @@ def on_press(event):
125187
fig.canvas.mpl_connect("button_press_event", on_press)
126188
fig.canvas.mpl_connect("button_release_event", on_release)
127189

128-
# probe shape
129-
planar_contour = probe.probe_planar_contour
130-
if planar_contour is not None:
131-
if probe.ndim == 2:
132-
poly_contour = PolyCollection([planar_contour], **_probe_shape_kwargs)
133-
ax.add_collection(poly_contour)
134-
elif probe.ndim == 3:
135-
poly_contour = Poly3DCollection([planar_contour], **_probe_shape_kwargs)
136-
ax.add_collection3d(poly_contour)
137-
else:
138-
poly_contour = None
139-
140190
if text_on_contact is not None:
141191
text_on_contact = np.asarray(text_on_contact)
142192
assert text_on_contact.size == probe.get_contact_count()
143193

194+
n = probe.get_contact_count()
144195
if with_contact_id or with_device_index or text_on_contact is not None:
145196
if probe.ndim == 3:
146197
raise NotImplementedError("Channel index is 2d only")

0 commit comments

Comments
 (0)