
Commit 3f3e28d

Merge pull request #741 from MouseLand/jacob/output_docs
Jacob/output docs
2 parents b82e562 + 5c2178e commit 3f3e28d

2 files changed: +230 -20 lines changed

kilosort/io.py

Lines changed: 123 additions & 0 deletions
@@ -212,6 +212,129 @@ def remove_bad_channels(probe, bad_channels):
 def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
                 data_dtype=None, save_extra_vars=False,
                 save_preprocessed_copy=False):
+    """Save sorting results to disk in a format readable by Phy.
+
+    Parameters
+    ----------
+    st : np.ndarray
+        3-column array of peak time (in samples), template, and amplitude for
+        each spike.
+    clu : np.ndarray
+        1D vector of cluster ids indicating which spike came from which cluster,
+        same shape as `st[:,0]`.
+    tF : torch.Tensor
+        PC features for each spike, with shape
+        (n_spikes, nearest_chans, n_pcs)
+    Wall : torch.Tensor
+        PC feature representation of spike waveforms for each cluster, with shape
+        (n_clusters, n_channels, n_pcs).
+    probe : dict
+        A Kilosort4 probe dictionary, as returned by `kilosort.io.load_probe`.
+    ops : dict
+        Dictionary storing settings and results for all algorithmic steps.
+    imin : int
+        Minimum sample index used by BinaryRWFile, exported spike times will
+        be shifted forward by this number.
+    results_dir : pathlib.Path; optional.
+        Directory where results should be saved.
+    data_dtype : str or type; optional.
+        dtype of data in binary file, like `'int32'` or `np.uint16`. By default,
+        dtype is assumed to be `'int16'`.
+    save_extra_vars : bool; default=False.
+        If True, save tF and Wall to disk along with copies of st, clu and
+        amplitudes with no postprocessing applied.
+    save_preprocessed_copy : bool; default=False.
+        If True, save a pre-processed copy of the data (including drift
+        correction) to `temp_wh.dat` in the results directory and format Phy
+        output to use that copy of the data.
+
+    Returns
+    -------
+    results_dir : pathlib.Path.
+        Directory where results are saved.
+    similar_templates : np.ndarray.
+        Similarity score between each pair of clusters, computed as correlation
+        between clusters. Shape (n_clusters, n_clusters).
+    is_ref : np.ndarray.
+        1D boolean array with shape (n_clusters,) indicating whether each
+        cluster is refractory.
+    est_contam_rate : np.ndarray.
+        Contamination rate for each cluster, computed as fraction of refractory
+        period violations relative to expectation based on a Poisson process.
+        Shape (n_clusters,).
+    kept_spikes : np.ndarray.
+        Boolean mask with shape (n_spikes,) that is False for spikes that were
+        removed by `kilosort.postprocessing.remove_duplicate_spikes`
+        and True otherwise.
+
+    Notes
+    -----
+    The following files will be saved in `results_dir`. Note that 'template'
+    here does *not* refer to the universal or learned templates used for spike
+    detection, as it did in some past versions of Kilosort. Instead, it refers
+    to the average spike waveform (after whitening, filtering, and drift
+    correction) for all spikes assigned to each cluster, which are template-like
+    in shape. We use the term 'template' anyway for this section because that is
+    how they are treated in Phy. Elsewhere in the Kilosort4 code, we would refer
+    to these as 'clusters.'
+
+    amplitudes.npy : shape (n_spikes,)
+        Per-spike amplitudes, computed as the L2 norm of the PC features
+        for each spike.
+    channel_map.npy : shape (n_channels,)
+        Same as probe['chanMap']. Integer indices into rows of binary file
+        that map the data to the contacts listed in the probe file.
+    channel_positions.npy : shape (n_channels,2)
+        Same as probe['xc'] and probe['yc'], but combined in a single array.
+        Indicates x- and y- positions (in microns) of probe contacts.
+    cluster_Amplitude.tsv : shape (n_templates,)
+        Per-template amplitudes, computed as the L2 norm of the template.
+    cluster_ContamPct.tsv : shape (n_templates,)
+        Contamination rate for each template, computed as fraction of refractory
+        period violations relative to expectation based on a Poisson process.
+    cluster_KSLabel.tsv : shape (n_templates,)
+        Label indicating whether each template is 'mua' (multi-unit activity)
+        or 'good' (refractory).
+    cluster_group.tsv : shape (n_templates,)
+        Same as `cluster_KSLabel.tsv`.
+    kept_spikes.npy : shape (n_spikes,)
+        Boolean mask that is False for spikes that were removed by
+        `kilosort.postprocessing.remove_duplicate_spikes` and True otherwise.
+    ops.npy : shape N/A
+        Dictionary containing a number of state variables saved throughout
+        the sorting process (see `run_kilosort`). We recommend loading with
+        `kilosort.io.load_ops`.
+    params.py : shape N/A
+        Settings used by Phy, like data location and sampling rate.
+    similar_templates.npy : shape (n_templates, n_templates)
+        Similarity score between each pair of templates, computed as correlation
+        between templates.
+    spike_clusters.npy : shape (n_spikes,)
+        For each spike, integer indicating which template it was assigned to.
+    spike_templates.npy : shape (n_spikes,)
+        Same as `spike_clusters.npy`.
+    spike_positions.npy : shape (n_spikes,2)
+        Estimated (x,y) position relative to probe geometry, in microns,
+        for each spike.
+    spike_times.npy : shape (n_spikes,)
+        Sample index of the waveform peak for each spike.
+    templates.npy : shape (n_templates, nt, n_channels)
+        Full time x channels template shapes.
+    templates_ind.npy : shape (n_templates, n_channels)
+        Channel indices on which each cluster is defined. For KS4, this is always
+        all channels, but Phy requires this file.
+    whitening_mat.npy : shape (n_channels, n_channels)
+        Matrix applied to data for whitening.
+    whitening_mat_inv.npy : shape (n_channels, n_channels)
+        Inverse of whitening matrix.
+    whitening_mat_dat.npy : shape (n_channels, n_channels)
+        Matrix applied to data for whitening. Currently this is the same as
+        `whitening_mat.npy`. It was added because the latter was previously
+        altered before saving for Phy, and this file ensured the original was
+        still saved. It is kept for now because we may need to change the
+        version used by Phy again in the future.
+
+    """
 
     if results_dir is None:
         results_dir = ops['data_dir'].joinpath('kilosort4')
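
The file list documented above can be read back with plain NumPy. The following is a minimal sketch, not part of this commit: `load_spike_trains` is a hypothetical helper, the sampling rate `fs` is supplied by the caller as an assumption rather than read from `params.py` or `ops.npy`, and the results directory is a synthetic stand-in so the example runs without a real sorting result.

```python
import tempfile
from pathlib import Path

import numpy as np


def load_spike_trains(results_dir, fs):
    """Group spike times (in seconds) by cluster id.

    Hypothetical helper for illustration only. `fs` is the sampling rate
    in Hz and is assumed to be known by the caller.
    """
    results_dir = Path(results_dir)
    # Per the docstring, both arrays have shape (n_spikes,).
    spike_times = np.load(results_dir / 'spike_times.npy')
    spike_clusters = np.load(results_dir / 'spike_clusters.npy')
    trains = {}
    for cluster_id in np.unique(spike_clusters):
        samples = spike_times[spike_clusters == cluster_id]
        trains[int(cluster_id)] = samples / fs
    return trains


# Synthetic stand-in for a results directory (made-up values).
with tempfile.TemporaryDirectory() as tmp:
    np.save(Path(tmp) / 'spike_times.npy', np.array([30, 60, 90, 150]))
    np.save(Path(tmp) / 'spike_clusters.npy', np.array([0, 1, 0, 1]))
    trains = load_spike_trains(tmp, fs=30000.0)

print(trains)
```

For `ops.npy`, the docstring recommends `kilosort.io.load_ops` rather than loading the file directly.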

kilosort/run_kilosort.py

Lines changed: 107 additions & 20 deletions
@@ -99,8 +99,38 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
 
     Returns
     -------
-    ops, st, clu, tF, Wall, similar_templates, is_ref, est_contam_rate, kept_spikes
-        Description TODO
+    ops : dict
+        Dictionary storing settings and results for all algorithmic steps.
+    st : np.ndarray
+        3-column array of peak time (in samples), template, and amplitude for
+        each spike.
+    clu : np.ndarray
+        1D vector of cluster ids indicating which spike came from which cluster,
+        same shape as `st[:,0]`.
+    tF : torch.Tensor
+        PC features for each spike, with shape
+        (n_spikes, nearest_chans, n_pcs)
+    Wall : torch.Tensor
+        PC feature representation of spike waveforms for each cluster, with shape
+        (n_clusters, n_channels, n_pcs).
+    similar_templates : np.ndarray.
+        Similarity score between each pair of clusters, computed as correlation
+        between clusters. Shape (n_clusters, n_clusters).
+    is_ref : np.ndarray.
+        1D boolean array with shape (n_clusters,) indicating whether each
+        cluster is refractory.
+    est_contam_rate : np.ndarray.
+        Contamination rate for each cluster, computed as fraction of refractory
+        period violations relative to expectation based on a Poisson process.
+        Shape (n_clusters,).
+    kept_spikes : np.ndarray.
+        Boolean mask with shape (n_spikes,) that is False for spikes that were
+        removed by `kilosort.postprocessing.remove_duplicate_spikes`
+        and True otherwise.
+
+    Notes
+    -----
+    For documentation of saved files, see `kilosort.io.save_to_phy`.
 
     """
 
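
To make the documented return shapes concrete, here is a small sketch of how the arrays might be consumed. It uses synthetic stand-ins rather than a real call to `run_kilosort`, and it assumes that cluster ids in `clu` index directly into `is_ref` (ids running from 0 to n_clusters - 1), which the docstring does not state explicitly.

```python
import numpy as np

# Synthetic stand-ins for values returned by `run_kilosort` (made-up data).
st = np.array([[100, 0, 12.5],
               [250, 1, 9.0],
               [400, 0, 11.0]])
clu = np.array([0, 1, 0])
is_ref = np.array([True, False])

# Columns of `st`, per the docstring: peak time (samples), template, amplitude.
peak_samples = st[:, 0].astype(np.int64)
template_ids = st[:, 1].astype(np.int64)
amplitudes = st[:, 2]

# Mirror the 'good' / 'mua' convention described for cluster_KSLabel.tsv.
labels = np.where(is_ref, 'good', 'mua')

# Spike times belonging to clusters flagged as refractory.
# Assumption: cluster ids are valid indices into `is_ref`.
good_ids = np.flatnonzero(is_ref)
good_spikes = peak_samples[np.isin(clu, good_ids)]

print(labels, good_spikes)
```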

@@ -447,8 +477,12 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
     Returns
     -------
     ops : dict
+        Dictionary storing settings and results for all algorithmic steps.
     bfile : kilosort.io.BinaryFiltered
         Wrapped file object for handling data.
+    st0 : np.ndarray.
+        Intermediate spike times variable with 6 columns. This is only used
+        for generating the 'Drift Scatter' plot through the GUI.
 
     """
 

@@ -493,7 +527,7 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
 
 
 def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
-    """Run spike sorting algorithm and save intermediate results to `ops`.
+    """Detect spikes via template deconvolution.
 
     Parameters
     ----------
@@ -511,14 +545,17 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
     Returns
     -------
     st : np.ndarray
-        1D vector of spike times for all clusters.
+        3-column array of peak time (in samples), template, and amplitude for
+        each spike.
     clu : np.ndarray
         1D vector of cluster ids indicating which spike came from which cluster,
         same shape as `st`.
-    tF : np.ndarray
-        TODO
-    Wall : np.ndarray
-        TODO
+    tF : torch.Tensor
+        PC features for each spike, with shape
+        (n_spikes, nearest_chans, n_pcs)
+    Wall : torch.Tensor
+        PC feature representation of spike waveforms for each cluster, with shape
+        (n_clusters, n_channels, n_pcs).
 
     """
 
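
The `save_to_phy` notes describe per-spike amplitudes as the L2 norm of the PC features for each spike. Given the documented `tF` shape, that computation could look like the sketch below. This is an illustration under assumptions: it uses a random NumPy array in place of the real `torch.Tensor`, and it takes the norm over both the channel and PC axes, which may not match Kilosort's own implementation exactly.

```python
import numpy as np

# Random stand-in for `tF` with the documented shape
# (n_spikes, nearest_chans, n_pcs). A real torch.Tensor would first need
# converting, e.g. with `tF.cpu().numpy()`.
rng = np.random.default_rng(0)
n_spikes, nearest_chans, n_pcs = 5, 10, 6
tF = rng.standard_normal((n_spikes, nearest_chans, n_pcs))

# L2 norm over all channel and PC entries, one value per spike.
amplitudes = np.sqrt((tF ** 2).sum(axis=(1, 2)))

print(amplitudes.shape)
```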

@@ -564,6 +601,37 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
 
 
 def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None):
+    """Cluster spikes using graph-based methods.
+
+    Parameters
+    ----------
+    st : np.ndarray
+        3-column array of peak time (in samples), template, and amplitude for
+        each spike.
+    tF : torch.Tensor
+        PC features for each spike, with shape
+        (n_spikes, nearest_chans, n_pcs)
+    ops : dict
+        Dictionary storing settings and results for all algorithmic steps.
+    device : torch.device
+        Indicates whether `pytorch` operations should be run on cpu or gpu.
+    bfile : kilosort.io.BinaryFiltered
+        Wrapped file object for handling data.
+    tic0 : float; default=np.nan.
+        Start time of `run_kilosort`.
+    progress_bar : TODO; optional.
+        Informs `tqdm` package how to report progress, type unclear.
+
+    Returns
+    -------
+    clu : np.ndarray
+        1D vector of cluster ids indicating which spike came from which cluster,
+        same shape as `st[:,0]`.
+    Wall : torch.Tensor
+        PC feature representation of spike waveforms for each cluster, with shape
+        (n_clusters, n_channels, n_pcs).
+
+    """
     tic = time.time()
     logger.info(' ')
     logger.info('Final clustering')
@@ -603,21 +671,25 @@ def save_sorting(ops, results_dir, st, clu, tF, Wall, imin, tic0=np.nan,
     results_dir : pathlib.Path
         Directory where results should be saved.
     st : np.ndarray
-        1D vector of spike times for all clusters.
+        3-column array of peak time (in samples), template, and amplitude for
+        each spike.
     clu : np.ndarray
         1D vector of cluster ids indicating which spike came from which cluster,
-        same shape as `st`.
-    tF : np.ndarray
-        TODO
-    Wall : np.ndarray
-        TODO
+        same shape as `st[:,0]`.
+    tF : torch.Tensor
+        PC features for each spike, with shape
+        (n_spikes, nearest_chans, n_pcs)
+    Wall : torch.Tensor
+        PC feature representation of spike waveforms for each cluster, with shape
+        (n_clusters, n_channels, n_pcs).
     imin : int
         Minimum sample index used by BinaryRWFile, exported spike times will
         be shifted forward by this number.
     tic0 : float; default=np.nan.
         Start time of `run_kilosort`.
     save_extra_vars : bool; default=False.
-        If True, save tF and Wall to disk after sorting.
+        If True, save tF and Wall to disk along with copies of st, clu and
+        amplitudes with no postprocessing applied.
     save_preprocessed_copy : bool; default=False.
         If True, save a pre-processed copy of the data (including drift
         correction) to `temp_wh.dat` in the results directory and format Phy
@@ -626,11 +698,26 @@ def save_sorting(ops, results_dir, st, clu, tF, Wall, imin, tic0=np.nan,
     Returns
     -------
     ops : dict
-    similar_templates : np.ndarray
-    is_ref : np.ndarray
-    est_contam_rate : np.ndarray
-    kept_spikes : np.ndarray
-
+        Dictionary storing settings and results for all algorithmic steps.
+    similar_templates : np.ndarray.
+        Similarity score between each pair of clusters, computed as correlation
+        between clusters. Shape (n_clusters, n_clusters).
+    is_ref : np.ndarray.
+        1D boolean array with shape (n_clusters,) indicating whether each
+        cluster is refractory.
+    est_contam_rate : np.ndarray.
+        Contamination rate for each cluster, computed as fraction of refractory
+        period violations relative to expectation based on a Poisson process.
+        Shape (n_clusters,).
+    kept_spikes : np.ndarray.
+        Boolean mask with shape (n_spikes,) that is False for spikes that were
+        removed by `kilosort.postprocessing.remove_duplicate_spikes`
+        and True otherwise.
+
+    Notes
+    -----
+    For documentation of saved files, see `kilosort.io.save_to_phy`.
+
     """
 
     logger.info(' ')
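
As a rough illustration of the `kept_spikes` and `similar_templates` return values documented above, the sketch below applies the boolean mask to per-spike arrays and finds each cluster's most similar neighbour. All values are made up, and it assumes the per-spike arrays being masked still include the duplicate spikes (so their first dimension matches `kept_spikes`), which the docstring does not state explicitly.

```python
import numpy as np

# Synthetic per-spike arrays (made-up data). The second spike stands in
# for one flagged by duplicate removal.
st = np.array([[100, 0, 12.5],
               [101, 0, 12.4],
               [400, 1, 9.0]])
clu = np.array([0, 0, 1])
kept_spikes = np.array([True, False, True])

# Assumption: `st` and `clu` have one row per entry of `kept_spikes`.
st_kept = st[kept_spikes]
clu_kept = clu[kept_spikes]

# Synthetic (n_clusters, n_clusters) similarity matrix.
similar_templates = np.array([[1.0, 0.2, 0.9],
                              [0.2, 1.0, 0.1],
                              [0.9, 0.1, 1.0]])

# Most similar *other* cluster for each cluster: mask the diagonal first
# so that a cluster is never matched with itself.
sim = similar_templates.copy()
np.fill_diagonal(sim, -np.inf)
nearest = sim.argmax(axis=1)

print(st_kept.shape, clu_kept, nearest)
```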
