Skip to content

Commit c00f68e

Browse files
unwhitened template amplitudes, saved normalized templates
1 parent 2483eb1 commit c00f68e

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

kilosort/io.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,18 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
156156
np.save((results_dir / 'channel_positions.npy'), channel_positions)
157157

158158
# whitening matrix ** saving real whitening matrix doesn't work with phy currently
159-
whitening_mat = ops['Wrot'].cpu().numpy()
160-
np.save((results_dir / 'whitening_mat_dat.npy'), whitening_mat)
159+
whitening_mat = ops['Wrot']
160+
np.save((results_dir / 'whitening_mat_dat.npy'), whitening_mat.cpu())
161161
# NOTE: commented out for reference, this was different in KS 2.5 because
162162
# the binary file was already whitened.
163163
# whitening_mat = 0.005 * np.eye(len(chan_map), dtype='float32')
164-
whitening_mat_inv = np.linalg.inv(whitening_mat + 1e-5 * np.eye(whitening_mat.shape[0]))
165-
np.save((results_dir / 'whitening_mat.npy'), whitening_mat)
166-
np.save((results_dir / 'whitening_mat_inv.npy'), whitening_mat_inv)
164+
whitening_mat_inv = torch.inverse(
165+
whitening_mat
166+
+ 1e-5 * torch.eye(whitening_mat.shape[0]).to(whitening_mat.device)
167+
)
168+
#whitening_mat_inv = np.linalg.inv(whitening_mat + 1e-5 * np.eye(whitening_mat.shape[0]))
169+
np.save((results_dir / 'whitening_mat.npy'), whitening_mat.cpu())
170+
np.save((results_dir / 'whitening_mat_inv.npy'), whitening_mat_inv.cpu())
167171

168172
# spike properties
169173
spike_times = st[:,0].astype('int64') + imin # shift by minimum sample index
@@ -191,13 +195,19 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
191195

192196
# template properties
193197
similar_templates = CCG.similarity(Wall, ops['wPCA'].contiguous(), nt=ops['nt'])
194-
template_amplitudes = ((Wall**2).sum(axis=(-2,-1))**0.5).cpu().numpy()
198+
temp_amplitudes = ((Wall**2).sum(axis=(-2,-1))**0.5).cpu().numpy()
195199
templates = (Wall.unsqueeze(-1).cpu() * ops['wPCA'].cpu()).sum(axis=-2).numpy()
196200
templates = templates.transpose(0,2,1)
201+
# normalize templates by amplitude
202+
templates = templates / temp_amplitudes[:, np.newaxis, np.newaxis]
197203
templates_ind = np.tile(np.arange(Wall.shape[1])[np.newaxis, :], (templates.shape[0],1))
198204
np.save((results_dir / 'similar_templates.npy'), similar_templates)
199205
np.save((results_dir / 'templates.npy'), templates)
200206
np.save((results_dir / 'templates_ind.npy'), templates_ind)
207+
# get unwhitened template amplitudes to use as cluster_Amplitudes
208+
iwrot = whitening_mat_inv.to(Wall.device)
209+
unwhitened = torch.einsum('jk, ikl -> ijl', iwrot, Wall)
210+
template_amplitudes = ((unwhitened**2).sum(axis=(-2,-1))**0.5).cpu().numpy()
201211

202212
# contamination ratio
203213
acg_threshold = ops['settings']['acg_threshold']

0 commit comments

Comments
 (0)