@@ -99,8 +99,38 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
99
99
100
100
Returns
101
101
-------
102
- ops, st, clu, tF, Wall, similar_templates, is_ref, est_contam_rate, kept_spikes
103
- Description TODO
102
+ ops : dict
103
+ Dictionary storing settings and results for all algorithmic steps.
104
+ st : np.ndarray
105
+ 3-column array of peak time (in samples), template, and amplitude for
106
+ each spike.
107
+ clu : np.ndarray
108
+ 1D vector of cluster ids indicating which spike came from which cluster,
109
+ same shape as `st[:,0]`.
110
+ tF : torch.Tensor
111
+ PC features for each spike, with shape
112
+ (n_spikes, nearest_chans, n_pcs)
113
+ Wall : torch.Tensor
114
+ PC feature representation of spike waveforms for each cluster, with shape
115
+ (n_clusters, n_channels, n_pcs).
116
+ similar_templates : np.ndarray.
117
+ Similarity score between each pair of clusters, computed as correlation
118
+ between clusters. Shape (n_clusters, n_clusters).
119
+ is_ref : np.ndarray.
120
+ 1D boolean array with shape (n_clusters,) indicating whether each
121
+ cluster is refractory.
122
+ est_contam_rate : np.ndarray.
123
+ Contamination rate for each cluster, computed as fraction of refractory
124
+ period violations relative to expectation based on a Poisson process.
125
+ Shape (n_clusters,).
126
+ kept_spikes : np.ndarray.
127
+ Boolean mask with shape (n_spikes,) that is False for spikes that were
128
+ removed by `kilosort.postprocessing.remove_duplicate_spikes`
129
+ and True otherwise.
130
+
131
+ Notes
132
+ -----
133
+ For documentation of saved files, see `kilosort.io.save_to_phy`.
104
134
105
135
"""
106
136
@@ -447,8 +477,12 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
447
477
Returns
448
478
-------
449
479
ops : dict
480
+ Dictionary storing settings and results for all algorithmic steps.
450
481
bfile : kilosort.io.BinaryFiltered
451
482
Wrapped file object for handling data.
483
+ st0 : np.ndarray.
484
+ Intermediate spike times variable with 6 columns. This is only used
485
+ for generating the 'Drift Scatter' plot through the GUI.
452
486
453
487
"""
454
488
@@ -493,7 +527,7 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
493
527
494
528
495
529
def detect_spikes (ops , device , bfile , tic0 = np .nan , progress_bar = None ):
496
- """Run spike sorting algorithm and save intermediate results to `ops` .
530
+ """Detect spikes via template deconvolution .
497
531
498
532
Parameters
499
533
----------
@@ -511,14 +545,17 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
511
545
Returns
512
546
-------
513
547
st : np.ndarray
514
- 1D vector of spike times for all clusters.
548
+ 3-column array of peak time (in samples), template, and amplitude for
549
+ each spike.
515
550
clu : np.ndarray
516
551
1D vector of cluster ids indicating which spike came from which cluster,
517
552
same shape as `st`.
518
- tF : np.ndarray
519
- TODO
520
- Wall : np.ndarray
521
- TODO
553
+ tF : torch.Tensor
554
+ PC features for each spike, with shape
555
+ (n_spikes, nearest_chans, n_pcs)
556
+ Wall : torch.Tensor
557
+ PC feature representation of spike waveforms for each cluster, with shape
558
+ (n_clusters, n_channels, n_pcs).
522
559
523
560
"""
524
561
@@ -564,6 +601,37 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
564
601
565
602
566
603
def cluster_spikes (st , tF , ops , device , bfile , tic0 = np .nan , progress_bar = None ):
604
+ """Cluster spikes using graph-based methods.
605
+
606
+ Parameters
607
+ ----------
608
+ st : np.ndarray
609
+ 3-column array of peak time (in samples), template, and amplitude for
610
+ each spike.
611
+ tF : torch.Tensor
612
+ PC features for each spike, with shape
613
+ (n_spikes, nearest_chans, n_pcs)
614
+ ops : dict
615
+ Dictionary storing settings and results for all algorithmic steps.
616
+ device : torch.device
617
+ Indicates whether `pytorch` operations should be run on cpu or gpu.
618
+ bfile : kilosort.io.BinaryFiltered
619
+ Wrapped file object for handling data.
620
+ tic0 : float; default=np.nan.
621
+ Start time of `run_kilosort`.
622
+ progress_bar : TODO; optional.
623
+ Informs `tqdm` package how to report progress, type unclear.
624
+
625
+ Returns
626
+ -------
627
+ clu : np.ndarray
628
+ 1D vector of cluster ids indicating which spike came from which cluster,
629
+ same shape as `st`.
630
+ Wall : torch.Tensor
631
+ PC feature representation of spike waveforms for each cluster, with shape
632
+ (n_clusters, n_channels, n_pcs).
633
+
634
+ """
567
635
tic = time .time ()
568
636
logger .info (' ' )
569
637
logger .info ('Final clustering' )
@@ -603,21 +671,25 @@ def save_sorting(ops, results_dir, st, clu, tF, Wall, imin, tic0=np.nan,
603
671
results_dir : pathlib.Path
604
672
Directory where results should be saved.
605
673
st : np.ndarray
606
- 1D vector of spike times for all clusters.
674
+ 3-column array of peak time (in samples), template, and amplitude for
675
+ each spike.
607
676
clu : np.ndarray
608
677
1D vector of cluster ids indicating which spike came from which cluster,
609
- same shape as `st`.
610
- tF : np.ndarray
611
- TODO
612
- Wall : np.ndarray
613
- TODO
678
+ same shape as `st[:,0]`.
679
+ tF : torch.Tensor
680
+ PC features for each spike, with shape
681
+ (n_spikes, nearest_chans, n_pcs)
682
+ Wall : torch.Tensor
683
+ PC feature representation of spike waveforms for each cluster, with shape
684
+ (n_clusters, n_channels, n_pcs).
614
685
imin : int
615
686
Minimum sample index used by BinaryRWFile, exported spike times will
616
687
be shifted forward by this number.
617
688
tic0 : float; default=np.nan.
618
689
Start time of `run_kilosort`.
619
690
save_extra_vars : bool; default=False.
620
- If True, save tF and Wall to disk after sorting.
691
+ If True, save tF and Wall to disk along with copies of st, clu and
692
+ amplitudes with no postprocessing applied.
621
693
save_preprocessed_copy : bool; default=False.
622
694
If True, save a pre-processed copy of the data (including drift
623
695
correction) to `temp_wh.dat` in the results directory and format Phy
@@ -626,11 +698,26 @@ def save_sorting(ops, results_dir, st, clu, tF, Wall, imin, tic0=np.nan,
626
698
Returns
627
699
-------
628
700
ops : dict
629
- similar_templates : np.ndarray
630
- is_ref : np.ndarray
631
- est_contam_rate : np.ndarray
632
- kept_spikes : np.ndarray
633
-
701
+ Dictionary storing settings and results for all algorithmic steps.
702
+ similar_templates : np.ndarray.
703
+ Similarity score between each pair of clusters, computed as correlation
704
+ between clusters. Shape (n_clusters, n_clusters).
705
+ is_ref : np.ndarray.
706
+ 1D boolean array with shape (n_clusters,) indicating whether each
707
+ cluster is refractory.
708
+ est_contam_rate : np.ndarray.
709
+ Contamination rate for each cluster, computed as fraction of refractory
710
+ period violations relative to expectation based on a Poisson process.
711
+ Shape (n_clusters,).
712
+ kept_spikes : np.ndarray.
713
+ Boolean mask with shape (n_spikes,) that is False for spikes that were
714
+ removed by `kilosort.postprocessing.remove_duplicate_spikes`
715
+ and True otherwise.
716
+
717
+ Notes
718
+ -----
719
+ For documentation of saved files, see `kilosort.io.save_to_phy`.
720
+
634
721
"""
635
722
636
723
logger .info (' ' )
0 commit comments