diff --git a/kilosort/clustering_qr.py b/kilosort/clustering_qr.py index 43b877c2..09064ac2 100644 --- a/kilosort/clustering_qr.py +++ b/kilosort/clustering_qr.py @@ -328,7 +328,8 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_b return clu, Wall -def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin = 20, dminx = 32, ncomps = 64): +def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin=20, dminx=32, + ncomps=64, ix=None, merge_dim=True): PID = torch.from_numpy(PID).long() #iU = ops['iU'].cpu().numpy() @@ -341,7 +342,11 @@ def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin = 20, dminx = 32, x0 = xcenter #xy[0].mean() - xcenter #print(dmin, dminx) - ix = torch.logical_and(torch.abs(xy[1] - y0) < dmin, torch.abs(xy[0] - x0) < dminx) + if ix is None: + ix = torch.logical_and( + torch.abs(xy[1] - y0) < dmin, + torch.abs(xy[0] - x0) < dminx + ) #print(ix.nonzero()[:,0]) igood = ix[PID].nonzero()[:,0] @@ -362,7 +367,12 @@ def get_data_cpu(ops, xy, iC, PID, tF, ycenter, xcenter, dmin = 20, dminx = 32, #print(ij.sum()) dd[ij.unsqueeze(-1), iC[:,j]-ch_min] = data[ij] - Xd = torch.reshape(dd, (nspikes, -1)) + if merge_dim: + Xd = torch.reshape(dd, (nspikes, -1)) + else: + # Keep channels and features separate + Xd = dd + return Xd, ch_min, ch_max, igood diff --git a/kilosort/io.py b/kilosort/io.py index 2755447d..f6bc7d45 100644 --- a/kilosort/io.py +++ b/kilosort/io.py @@ -13,7 +13,9 @@ from kilosort import CCG from kilosort.preprocessing import get_drift_matrix, fft_highpass -from kilosort.postprocessing import remove_duplicates, compute_spike_positions +from kilosort.postprocessing import ( + remove_duplicates, compute_spike_positions, make_pc_features + ) def find_binary(data_dir: Union[str, os.PathLike]) -> Path: @@ -195,6 +197,19 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None, np.save((results_dir / 'templates.npy'), templates) np.save((results_dir / 'templates_ind.npy'), 
templates_ind) + # pc features + if save_extra_vars: + # Save tF first since it gets updated in-place + np.save(results_dir / 'tF.npy', tF.cpu().numpy()) + # This will momentarily copy tF which is pretty large, but it's on CPU + # so the extra memory hopefully won't be an issue. + tF = tF[kept_spikes] + pc_features, pc_feature_ind = make_pc_features( + ops, spike_templates, spike_clusters, tF + ) + np.save(results_dir / 'pc_features.npy', pc_features) + np.save(results_dir / 'pc_feature_ind.npy', pc_feature_ind) + # contamination ratio acg_threshold = ops['settings']['acg_threshold'] ccg_threshold = ops['settings']['ccg_threshold'] @@ -231,14 +246,19 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None, f.write(f'{key} = {params[key]}\n') if save_extra_vars: - # Also save tF and Wall, for easier debugging/analysis - np.save(results_dir / 'tF.npy', tF.cpu().numpy()) + # Also save Wall, for easier debugging/analysis np.save(results_dir / 'Wall.npy', Wall.cpu().numpy()) # And full st, clu, amp arrays with no spikes removed np.save(results_dir / 'full_st.npy', st) np.save(results_dir / 'full_clu.npy', clu) np.save(results_dir / 'full_amp.npy', amplitudes) + # Remove cached .phy results if present from running Phy on a previous + # version of results in the same directory. 
+ phy_cache_path = Path(results_dir / '.phy') + if phy_cache_path.is_dir(): + shutil.rmtree(phy_cache_path) + return results_dir, similar_templates, is_ref, est_contam_rate diff --git a/kilosort/postprocessing.py b/kilosort/postprocessing.py index 4244d883..466fa7d5 100644 --- a/kilosort/postprocessing.py +++ b/kilosort/postprocessing.py @@ -3,6 +3,8 @@ import numpy as np import torch +from kilosort.clustering_qr import xy_templates, get_data_cpu + @njit("(int64[:], int32[:], int32)") def remove_duplicates(spike_times, spike_clusters, dt=15): @@ -42,3 +44,70 @@ def compute_spike_positions(st, tF, ops): ys = (yc0 * tmass).sum(1).cpu().numpy() return xs, ys + + +def make_pc_features(ops, spike_templates, spike_clusters, tF): + '''Get PC Features and corresponding indices for export to Phy. + + NOTE: This function will update tF in-place! + + Parameters + ---------- + ops : dict + Dictionary of state variables updated throughout the sorting process. + This function is intended to be used with the final state of ops, after + all sorting has finished. + spike_templates : np.ndarray + Vector of template ids with shape `(n_spikes,)`. This is equivalent to + `st[:,1]`, where `st` is returned by `template_matching.extract`. + spike_clusters : np.ndarray + Vector of cluster ids with shape `(n_spikes,)`. This is equivalent to + `clu` returned by `template_matching.merging_function`. + tF : torch.Tensor + Tensor of pc features as returned by `template_matching.extract`, + with shape `(n_spikes, nearest_chans, n_pcs)`. + + Returns + ------- + tF : torch.Tensor + As above, but with some data replaced so that features are associated + with the final clusters instead of templates. The second and third + dimensions are also swapped to conform to the shape expected by Phy. + feature_ind : np.ndarray + Channel indices associated with the data present in tF for each cluster, + with shape `(n_clusters, nearest_chans)`. 
+ + ''' + + # xy: template centers, iC: channels associated with each template + xy, iC = xy_templates(ops) + n_clusters = np.unique(spike_clusters).size + n_chans = ops['nearest_chans'] + feature_ind = np.zeros((n_clusters, n_chans), dtype=np.uint32) + + for i in np.unique(spike_clusters): + # Get templates associated with cluster (often just 1) + iunq = np.unique(spike_templates[spike_clusters==i]).astype(int) + # Get boolean mask with size (n_templates,), True if they match cluster + ix = torch.from_numpy(np.zeros(int(spike_templates.max())+1, bool)) + ix[iunq] = True + # Get PC features for all spikes detected with those templates (Xd), + # and the indices in tF where those spikes occur (igood). + Xd, ch_min, ch_max, igood = get_data_cpu( + ops, xy, iC, spike_templates, tF, None, None, + dmin=ops['dmin'], dminx=ops['dminx'], ix=ix, merge_dim=False + ) + + # Take mean of features across spikes, find channels w/ largest norm + spike_mean = Xd.mean(0) + chan_norm = torch.linalg.norm(spike_mean, dim=1) + sorted_chans, ind = torch.sort(chan_norm, descending=True) + # Assign features to overwrite tF in-place + tF[igood,:] = Xd[:, ind[:n_chans], :] + # Save channel inds for phy + feature_ind[i,:] = ind[:n_chans].numpy() + ch_min.cpu().numpy() + + # Swap last 2 dimensions to get ordering Phy expects + tF = torch.permute(tF, (0, 2, 1)) + + return tF, feature_ind diff --git a/kilosort/run_kilosort.py b/kilosort/run_kilosort.py index a7529c6d..479a2e16 100644 --- a/kilosort/run_kilosort.py +++ b/kilosort/run_kilosort.py @@ -512,10 +512,11 @@ def load_sorting(results_dir, device=None, load_extra_vars=False): results = [ops, st, clu, similar_templates, is_ref, est_contam_rate] if load_extra_vars: + # NOTE: tF and Wall always go on CPU, not CUDA tF = np.load(results_dir / 'tF.npy') - tF = torch.from_numpy(tF).to(device) + tF = torch.from_numpy(tF) Wall = np.load(results_dir / 'Wall.npy') - Wall = torch.from_numpy(Wall).to(device) + Wall = torch.from_numpy(Wall) full_st = 
np.load(results_dir / 'full_st.npy') full_clu = np.load(results_dir / 'full_clu.npy') full_amp = np.load(results_dir / 'full_amp.npy')