Skip to content

Commit 2edc731

Browse files
authored
Merge pull request #3842 from samuelgarcia/tdc_peeler_improve
Plot drifting templates
2 parents fb3bda8 + 76a163b commit 2edc731

File tree

4 files changed

+140
-2
lines changed

4 files changed

+140
-2
lines changed

src/spikeinterface/sortingcomponents/matching/tdc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ def __init__(
152152
self.sparse_templates_array_static = None
153153

154154
# interpolation bins edges
155-
156155
self.interpolation_time_bins_s = []
157156
self.interpolation_time_bin_edges_s = []
158157
for segment_index, parent_segment in enumerate(recording._recording_segments):
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from __future__ import annotations
2+
3+
from probeinterface import ProbeGroup
4+
5+
from .base import BaseWidget, to_attr
6+
from .utils import get_unit_colors
7+
from spikeinterface.core.sortinganalyzer import SortingAnalyzer
8+
9+
from .unit_templates import UnitTemplatesWidget
10+
from ..core import Templates
11+
12+
13+
class DriftingTemplatesWidget(BaseWidget):
14+
"""
15+
Plot a drifting templates object to explore motion
16+
17+
Parameters
18+
----------
19+
drifting_templates :
20+
A drifting templates object
21+
scale : float, default: 1
22+
Scale factor for the waveforms/templates (matplotlib backend)
23+
"""
24+
25+
def __init__(
26+
self,
27+
drifting_templates: SortingAnalyzer,
28+
scale=1,
29+
backend=None,
30+
**backend_kwargs,
31+
):
32+
self.drifting_templates = drifting_templates
33+
34+
data_plot = dict(
35+
drifting_templates=drifting_templates,
36+
)
37+
38+
BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs)
39+
40+
def plot_ipywidgets(self, data_plot, **backend_kwargs):
41+
import matplotlib.pyplot as plt
42+
import ipywidgets.widgets as widgets
43+
from IPython.display import display
44+
from .utils_ipywidgets import check_ipywidget_backend, UnitSelector
45+
46+
check_ipywidget_backend()
47+
48+
# self.next_data_plot = data_plot.copy()
49+
self.drifting_templates = data_plot["drifting_templates"]
50+
51+
cm = 1 / 2.54
52+
53+
width_cm = backend_kwargs["width_cm"]
54+
height_cm = backend_kwargs["height_cm"]
55+
56+
ratios = [0.15, 0.85]
57+
58+
with plt.ioff():
59+
output = widgets.Output()
60+
with output:
61+
fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm))
62+
plt.show()
63+
64+
unit_ids = self.drifting_templates.unit_ids
65+
self.unit_selector = UnitSelector(unit_ids)
66+
self.unit_selector.value = list(unit_ids)[:1]
67+
68+
arr = self.drifting_templates.templates_array_moved
69+
70+
self.slider = widgets.IntSlider(
71+
orientation="horizontal",
72+
value=arr.shape[0] // 2,
73+
min=0,
74+
max=arr.shape[0] - 1,
75+
readout=False,
76+
continuous_update=True,
77+
layout=widgets.Layout(width=f"100%"),
78+
)
79+
80+
self.widget = widgets.AppLayout(
81+
center=fig.canvas,
82+
left_sidebar=self.unit_selector,
83+
pane_widths=ratios + [0],
84+
footer=self.slider,
85+
)
86+
87+
self._update_ipywidget()
88+
89+
self.unit_selector.observe(self._change_unit, names="value", type="change")
90+
self.slider.observe(self._change_displacement, names="value", type="change")
91+
92+
if backend_kwargs["display"]:
93+
display(self.widget)
94+
95+
def _change_unit(self, change=None):
96+
self._update_ipywidget(keep_lims=False)
97+
98+
def _change_displacement(self, change=None):
99+
self._update_ipywidget(keep_lims=True)
100+
101+
def _update_ipywidget(self, keep_lims=False):
102+
if keep_lims:
103+
xlim = self.ax.get_xlim()
104+
ylim = self.ax.get_ylim()
105+
106+
self.ax.clear()
107+
unit_ids = self.unit_selector.value
108+
109+
displacement_index = self.slider.value
110+
111+
templates_array = self.drifting_templates.templates_array_moved[displacement_index, :, :, :]
112+
templates = Templates(
113+
templates_array,
114+
self.drifting_templates.sampling_frequency,
115+
self.drifting_templates.nbefore,
116+
is_scaled=self.drifting_templates.is_scaled,
117+
sparsity_mask=None,
118+
channel_ids=self.drifting_templates.channel_ids,
119+
unit_ids=self.drifting_templates.unit_ids,
120+
probe=self.drifting_templates.probe,
121+
)
122+
123+
UnitTemplatesWidget(
124+
templates, unit_ids=unit_ids, scale=5, plot_legend=False, backend="matplotlib", ax=self.ax, same_axis=True
125+
)
126+
127+
displacement = self.drifting_templates.displacements[displacement_index]
128+
self.ax.set_title(f"{displacement_index}:{displacement} - untis:{unit_ids}")
129+
130+
if keep_lims:
131+
self.ax.set_xlim(xlim)
132+
self.ax.set_ylim(ylim)
133+
134+
fig = self.ax.get_figure()
135+
fig.canvas.draw()
136+
fig.canvas.flush_events()

src/spikeinterface/widgets/unit_waveforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def __init__(
167167

168168
# get templates
169169
if isinstance(sorting_analyzer_or_templates, Templates):
170-
templates = sorting_analyzer_or_templates.templates_array
170+
templates = sorting_analyzer_or_templates.select_units(unit_ids).templates_array
171171
nbefore = sorting_analyzer_or_templates.nbefore
172172
self.templates_ext = None
173173
templates_shading = None

src/spikeinterface/widgets/widget_list.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .amplitudes import AmplitudesWidget
1010
from .autocorrelograms import AutoCorrelogramsWidget
1111
from .crosscorrelograms import CrossCorrelogramsWidget
12+
from .drift_templates import DriftingTemplatesWidget
1213
from .isi_distribution import ISIDistributionWidget
1314
from .motion import DriftRasterMapWidget, MotionWidget, MotionInfoWidget
1415
from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget
@@ -45,6 +46,7 @@
4546
ConfusionMatrixWidget,
4647
ComparisonCollisionBySimilarityWidget,
4748
CrossCorrelogramsWidget,
49+
DriftingTemplatesWidget,
4850
DriftRasterMapWidget,
4951
ISIDistributionWidget,
5052
LocationsWidget,
@@ -124,6 +126,7 @@
124126
plot_confusion_matrix = ConfusionMatrixWidget
125127
plot_comparison_collision_by_similarity = ComparisonCollisionBySimilarityWidget
126128
plot_crosscorrelograms = CrossCorrelogramsWidget
129+
plot_drifting_templates = DriftingTemplatesWidget
127130
plot_drift_raster_map = DriftRasterMapWidget
128131
plot_isi_distribution = ISIDistributionWidget
129132
plot_locations = LocationsWidget

0 commit comments

Comments
 (0)