From 3693ad7e6fa8d94db4e4b2628ced534a0c69a063 Mon Sep 17 00:00:00 2001 From: jacobpennington Date: Tue, 12 Mar 2024 09:20:07 -0700 Subject: [PATCH 1/7] Fixed tF, Wall in load_sorting, should always be cpu --- kilosort/run_kilosort.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/kilosort/run_kilosort.py b/kilosort/run_kilosort.py index a7529c6d..479a2e16 100644 --- a/kilosort/run_kilosort.py +++ b/kilosort/run_kilosort.py @@ -512,10 +512,11 @@ def load_sorting(results_dir, device=None, load_extra_vars=False): results = [ops, st, clu, similar_templates, is_ref, est_contam_rate] if load_extra_vars: + # NOTE: tF and Wall always go on CPU, not CUDA tF = np.load(results_dir / 'tF.npy') - tF = torch.from_numpy(tF).to(device) + tF = torch.from_numpy(tF) Wall = np.load(results_dir / 'Wall.npy') - Wall = torch.from_numpy(Wall).to(device) + Wall = torch.from_numpy(Wall) full_st = np.load(results_dir / 'full_st.npy') full_clu = np.load(results_dir / 'full_clu.npy') full_amp = np.load(results_dir / 'full_amp.npy') From e68b44aa4b9f5351920f6e9d4f471c814fda5f43 Mon Sep 17 00:00:00 2001 From: jacobpennington Date: Tue, 12 Mar 2024 09:20:45 -0700 Subject: [PATCH 2/7] Added options to get_data_cpu for exporting pc features --- kilosort/clustering_qr.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/kilosort/clustering_qr.py b/kilosort/clustering_qr.py index 43b877c2..5aa80db0 100644 --- a/kilosort/clustering_qr.py +++ b/kilosort/clustering_qr.py @@ -328,7 +328,8 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_b return clu, Wall -def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin = 20, dminx = 32, ncomps = 64): +def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin=20, dminx=32, + ncomps=64, ix=None, merge_dim=True): PID = torch.from_numpy(PID).long() #iU = ops['iU'].cpu().numpy() @@ -341,7 +342,11 @@ def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin = 20, 
dminx = 32, x0 = xcenter #xy[0].mean() - xcenter #print(dmin, dminx) - ix = torch.logical_and(torch.abs(xy[1] - y0) < dmin, torch.abs(xy[0] - x0) < dminx) + if ix is None: + ix = torch.logical_and( + torch.abs(xy[1] - y0) < dmin, + torch.abs(xy[0] - x0) < dminx + ) #print(ix.nonzero()[:,0]) igood = ix[PID].nonzero()[:,0] @@ -362,7 +367,12 @@ def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin = 20, dminx = 32, #print(ij.sum()) dd[ij.unsqueeze(-1), iC[:,j]-ch_min] = data[ij] - Xd = torch.reshape(dd, (nspikes, -1)) + if merge_dim: + Xd = torch.reshape(dd, (nspikes, -1)) + else: + # Keep channels and features separate + Xd = torch.reshape(dd, (nspikes, -1, 6)) + return Xd, ch_min, ch_max, igood From 5c329c2bc4916cbc3f6f568ca94e1179dc256028 Mon Sep 17 00:00:00 2001 From: Jacob Pennington Date: Fri, 22 Mar 2024 11:05:15 -0400 Subject: [PATCH 3/7] small fix to get data cpu --- kilosort/clustering_qr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kilosort/clustering_qr.py b/kilosort/clustering_qr.py index 5aa80db0..09064ac2 100644 --- a/kilosort/clustering_qr.py +++ b/kilosort/clustering_qr.py @@ -371,7 +371,7 @@ def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin=20, dminx=32, Xd = torch.reshape(dd, (nspikes, -1)) else: # Keep channels and features separate - Xd = torch.reshape(dd, (nspikes, -1, 6)) + Xd = dd return Xd, ch_min, ch_max, igood From a49e07e8ff35e6ef849ba3be624b9725d6485c13 Mon Sep 17 00:00:00 2001 From: Jacob Pennington Date: Wed, 27 Mar 2024 09:53:07 -0400 Subject: [PATCH 4/7] Added pc features to phy output --- kilosort/io.py | 20 +++++++++++++++++--- kilosort/postprocessing.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/kilosort/io.py b/kilosort/io.py index 2755447d..a8c9e6ae 100644 --- a/kilosort/io.py +++ b/kilosort/io.py @@ -13,7 +13,9 @@ from kilosort import CCG from kilosort.preprocessing import get_drift_matrix, fft_highpass -from 
kilosort.postprocessing import remove_duplicates, compute_spike_positions +from kilosort.postprocessing import ( + remove_duplicates, compute_spike_positions, make_pc_features + ) def find_binary(data_dir: Union[str, os.PathLike]) -> Path: @@ -195,6 +197,19 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None, np.save((results_dir / 'templates.npy'), templates) np.save((results_dir / 'templates_ind.npy'), templates_ind) + # pc features + if save_extra_vars: + # Save tF first since it gets updated in-place + np.save(results_dir / 'tF.npy', tF.cpu().numpy()) + # This will momentarily copy tF which is pretty large, but it's on CPU + # so the extra memory hopefully won't be an issue. + tF = tF[kept_spikes] + pc_features, pc_feature_ind = make_pc_features( + ops, spike_templates, spike_clusters, tF + ) + np.save(results_dir / 'pc_features.npy', pc_features) + np.save(results_dir / 'pc_feature_ind.npy', pc_feature_ind) + # contamination ratio acg_threshold = ops['settings']['acg_threshold'] ccg_threshold = ops['settings']['ccg_threshold'] @@ -231,8 +246,7 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None, f.write(f'{key} = {params[key]}\n') if save_extra_vars: - # Also save tF and Wall, for easier debugging/analysis - np.save(results_dir / 'tF.npy', tF.cpu().numpy()) + # Also save Wall, for easier debugging/analysis np.save(results_dir / 'Wall.npy', Wall.cpu().numpy()) # And full st, clu, amp arrays with no spikes removed np.save(results_dir / 'full_st.npy', st) diff --git a/kilosort/postprocessing.py b/kilosort/postprocessing.py index 4244d883..a06656ea 100644 --- a/kilosort/postprocessing.py +++ b/kilosort/postprocessing.py @@ -3,6 +3,8 @@ import numpy as np import torch +from kilosort.clustering_qr import xy_templates, get_data_cpu + @njit("(int64[:], int32[:], int32)") def remove_duplicates(spike_times, spike_clusters, dt=15): @@ -42,3 +44,36 @@ def compute_spike_positions(st, tF, ops): ys = (yc0 * 
tmass).sum(1).cpu().numpy() return xs, ys + + +def make_pc_features(ops, spike_templates, spike_clusters, tF): + # spike_templates: st[:,1] + # spike clusters: clu + + xy, iC = xy_templates(ops) + n_clusters = np.unique(spike_clusters).size + feature_ind = np.zeros((n_clusters, 10), dtype=np.uint32) + + for i in np.unique(spike_clusters): + iunq = np.unique(spike_templates[spike_clusters==i]).astype(int) + ix = torch.from_numpy(np.zeros(int(spike_templates.max())+1, bool)) + ix[iunq] = True + Xd, ch_min, ch_max, igood = get_data_cpu( + ops, xy, iC, spike_templates, tF, None, None, + dmin=ops['dmin'], dminx=ops['dminx'], ix=ix, merge_dim=False + ) + + # Take mean of Xd across spikes, find channels w/ largest norm + spike_mean = Xd.mean(0) + chan_norm = torch.linalg.norm(spike_mean, dim=1) + sorted_chans, ind = torch.sort(chan_norm, descending=True) + # Assign Xd to overwrite tF in-place + tF[igood,:] = Xd[:, ind[:10], :] + # Save channel inds for phy + feature_ind[i,:] = ind[:10].numpy() + ch_min.cpu().numpy() + # TODO: should be sorted by physical distance from first channel? + # TODO: cast to uint32 + + tF = torch.permute(tF, (0, 2, 1)) + + return tF, feature_ind From 2d628432a067e5821556fafc2f07e8df19cc2426 Mon Sep 17 00:00:00 2001 From: Jacob Pennington Date: Wed, 27 Mar 2024 10:25:25 -0400 Subject: [PATCH 5/7] Added documentation for make_pc_features --- kilosort/postprocessing.py | 47 ++++++++++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/kilosort/postprocessing.py b/kilosort/postprocessing.py index a06656ea..fddd33e3 100644 --- a/kilosort/postprocessing.py +++ b/kilosort/postprocessing.py @@ -47,33 +47,66 @@ def compute_spike_positions(st, tF, ops): def make_pc_features(ops, spike_templates, spike_clusters, tF): - # spike_templates: st[:,1] - # spike clusters: clu + '''Get PC Features and corresponding indices for export to Phy. + NOTE: This function will update tF in-place! 
+ + Parameters + ---------- + ops : dict + Dictionary of state variables updated throughout the sorting process. + This function is intended to be used with the final state of ops, after + all sorting has finished. + spike_templates : np.ndarray + Vector of template ids with shape `(n_spikes,)`. This is equivalent to + `st[:,1]`, where `st` is returned by `template_matching.extract`. + spike_clusters : np.ndarray + Vector of cluster ids with shape `(n_spikes,)`. This is equivalent to + `clu` returned by `template_matching.merging_function`. + tF : torch.Tensor + Tensor of pc features as returned by `template_matching.extract`, + with shape `(n_spikes, nearest_chans, n_pcs)`. + + Returns + ------- + tF : torch.Tensor + As above, but with some data replaced so that features are associated + with the final clusters instead of templates. The second and third + dimensions are also swapped to conform to the shape expected by Phy. + feature_ind : np.ndarray + Channel indices associated with the data present in tF for each cluster, + with shape `(n_clusters, nearest_chans)`. + + ''' + + # xy: template centers, iC: channels associated with each template xy, iC = xy_templates(ops) n_clusters = np.unique(spike_clusters).size - feature_ind = np.zeros((n_clusters, 10), dtype=np.uint32) + feature_ind = np.zeros((n_clusters, ops['nearest_chans']), dtype=np.uint32) for i in np.unique(spike_clusters): + # Get templates associated with cluster (often just 1) iunq = np.unique(spike_templates[spike_clusters==i]).astype(int) + # Get boolean mask with size (n_templates,), True if they match cluster ix = torch.from_numpy(np.zeros(int(spike_templates.max())+1, bool)) ix[iunq] = True + # Get PC features for all spikes detected with those templates (Xd), + # and the indices in tF where those spikes occur (igood). 
Xd, ch_min, ch_max, igood = get_data_cpu( ops, xy, iC, spike_templates, tF, None, None, dmin=ops['dmin'], dminx=ops['dminx'], ix=ix, merge_dim=False ) - # Take mean of Xd across spikes, find channels w/ largest norm + # Take mean of features across spikes, find channels w/ largest norm spike_mean = Xd.mean(0) chan_norm = torch.linalg.norm(spike_mean, dim=1) sorted_chans, ind = torch.sort(chan_norm, descending=True) - # Assign Xd to overwrite tF in-place + # Assign features to overwrite tF in-place tF[igood,:] = Xd[:, ind[:10], :] # Save channel inds for phy feature_ind[i,:] = ind[:10].numpy() + ch_min.cpu().numpy() - # TODO: should be sorted by physical distance from first channel? - # TODO: cast to uint32 + # Swap last 2 dimensions to get ordering Phy expects tF = torch.permute(tF, (0, 2, 1)) return tF, feature_ind From 78958a2fe6fdbd2c604dad984c7644fdb9b6bd04 Mon Sep 17 00:00:00 2001 From: Jacob Pennington Date: Wed, 27 Mar 2024 10:52:23 -0400 Subject: [PATCH 6/7] .phy cache now removed by save_to_phy when overwriting --- kilosort/io.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/kilosort/io.py b/kilosort/io.py index a8c9e6ae..f6bc7d45 100644 --- a/kilosort/io.py +++ b/kilosort/io.py @@ -253,6 +253,12 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None, np.save(results_dir / 'full_clu.npy', clu) np.save(results_dir / 'full_amp.npy', amplitudes) + # Remove cached .phy results if present from running Phy on a previous + # version of results in the same directory. 
+ phy_cache_path = Path(results_dir / '.phy') + if phy_cache_path.is_dir(): + shutil.rmtree(phy_cache_path) + return results_dir, similar_templates, is_ref, est_contam_rate From 207d4a3585f91ece67b69edceb6599ecfb4eddf6 Mon Sep 17 00:00:00 2001 From: Jacob Pennington Date: Wed, 27 Mar 2024 10:52:51 -0400 Subject: [PATCH 7/7] removed hard-coded nearest chans in make_pc_features --- kilosort/postprocessing.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/kilosort/postprocessing.py b/kilosort/postprocessing.py index fddd33e3..466fa7d5 100644 --- a/kilosort/postprocessing.py +++ b/kilosort/postprocessing.py @@ -82,7 +82,8 @@ def make_pc_features(ops, spike_templates, spike_clusters, tF): # xy: template centers, iC: channels associated with each template xy, iC = xy_templates(ops) n_clusters = np.unique(spike_clusters).size - feature_ind = np.zeros((n_clusters, ops['nearest_chans']), dtype=np.uint32) + n_chans = ops['nearest_chans'] + feature_ind = np.zeros((n_clusters, n_chans), dtype=np.uint32) for i in np.unique(spike_clusters): # Get templates associated with cluster (often just 1) @@ -102,9 +103,9 @@ def make_pc_features(ops, spike_templates, spike_clusters, tF): chan_norm = torch.linalg.norm(spike_mean, dim=1) sorted_chans, ind = torch.sort(chan_norm, descending=True) # Assign features to overwrite tF in-place - tF[igood,:] = Xd[:, ind[:10], :] + tF[igood,:] = Xd[:, ind[:n_chans], :] # Save channel inds for phy - feature_ind[i,:] = ind[:10].numpy() + ch_min.cpu().numpy() + feature_ind[i,:] = ind[:n_chans].numpy() + ch_min.cpu().numpy() # Swap last 2 dimensions to get ordering Phy expects tF = torch.permute(tF, (0, 2, 1))