Skip to content

Commit

Permalink
unwhitened template amplitudes, saved normalized templates
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobpennington committed Mar 27, 2024
1 parent 2483eb1 commit c00f68e
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions kilosort/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,18 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
np.save((results_dir / 'channel_positions.npy'), channel_positions)

# whitening matrix ** saving real whitening matrix doesn't work with phy currently
whitening_mat = ops['Wrot'].cpu().numpy()
np.save((results_dir / 'whitening_mat_dat.npy'), whitening_mat)
whitening_mat = ops['Wrot']
np.save((results_dir / 'whitening_mat_dat.npy'), whitening_mat.cpu())
# NOTE: commented out for reference, this was different in KS 2.5 because
# the binary file was already whitened.
# whitening_mat = 0.005 * np.eye(len(chan_map), dtype='float32')
whitening_mat_inv = np.linalg.inv(whitening_mat + 1e-5 * np.eye(whitening_mat.shape[0]))
np.save((results_dir / 'whitening_mat.npy'), whitening_mat)
np.save((results_dir / 'whitening_mat_inv.npy'), whitening_mat_inv)
whitening_mat_inv = torch.inverse(
whitening_mat
+ 1e-5 * torch.eye(whitening_mat.shape[0]).to(whitening_mat.device)
)
#whitening_mat_inv = np.linalg.inv(whitening_mat + 1e-5 * np.eye(whitening_mat.shape[0]))
np.save((results_dir / 'whitening_mat.npy'), whitening_mat.cpu())
np.save((results_dir / 'whitening_mat_inv.npy'), whitening_mat_inv.cpu())

# spike properties
spike_times = st[:,0].astype('int64') + imin # shift by minimum sample index
Expand Down Expand Up @@ -191,13 +195,19 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,

# template properties
similar_templates = CCG.similarity(Wall, ops['wPCA'].contiguous(), nt=ops['nt'])
template_amplitudes = ((Wall**2).sum(axis=(-2,-1))**0.5).cpu().numpy()
temp_amplitudes = ((Wall**2).sum(axis=(-2,-1))**0.5).cpu().numpy()
templates = (Wall.unsqueeze(-1).cpu() * ops['wPCA'].cpu()).sum(axis=-2).numpy()
templates = templates.transpose(0,2,1)
# normalize templates by amplitude
templates = templates / temp_amplitudes[:, np.newaxis, np.newaxis]
templates_ind = np.tile(np.arange(Wall.shape[1])[np.newaxis, :], (templates.shape[0],1))
np.save((results_dir / 'similar_templates.npy'), similar_templates)
np.save((results_dir / 'templates.npy'), templates)
np.save((results_dir / 'templates_ind.npy'), templates_ind)
# get unwhitened template amplitudes to use as cluster_Amplitudes
iwrot = whitening_mat_inv.to(Wall.device)
unwhitened = torch.einsum('jk, ikl -> ijl', iwrot, Wall)
template_amplitudes = ((unwhitened**2).sum(axis=(-2,-1))**0.5).cpu().numpy()

# contamination ratio
acg_threshold = ops['settings']['acg_threshold']
Expand Down

0 comments on commit c00f68e

Please sign in to comment.