@@ -156,14 +156,18 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
156156 np .save ((results_dir / 'channel_positions.npy' ), channel_positions )
157157
158158 # whitening matrix ** saving real whitening matrix doesn't work with phy currently
159- whitening_mat = ops ['Wrot' ]. cpu (). numpy ()
160- np .save ((results_dir / 'whitening_mat_dat.npy' ), whitening_mat )
159+ whitening_mat = ops ['Wrot' ]
160+ np .save ((results_dir / 'whitening_mat_dat.npy' ), whitening_mat . cpu () )
161161 # NOTE: commented out for reference, this was different in KS 2.5 because
162162 # the binary file was already whitened.
163163 # whitening_mat = 0.005 * np.eye(len(chan_map), dtype='float32')
164- whitening_mat_inv = np .linalg .inv (whitening_mat + 1e-5 * np .eye (whitening_mat .shape [0 ]))
165- np .save ((results_dir / 'whitening_mat.npy' ), whitening_mat )
166- np .save ((results_dir / 'whitening_mat_inv.npy' ), whitening_mat_inv )
164+ whitening_mat_inv = torch .inverse (
165+ whitening_mat
166+ + 1e-5 * torch .eye (whitening_mat .shape [0 ]).to (whitening_mat .device )
167+ )
168+ #whitening_mat_inv = np.linalg.inv(whitening_mat + 1e-5 * np.eye(whitening_mat.shape[0]))
169+ np .save ((results_dir / 'whitening_mat.npy' ), whitening_mat .cpu ())
170+ np .save ((results_dir / 'whitening_mat_inv.npy' ), whitening_mat_inv .cpu ())
167171
168172 # spike properties
169173 spike_times = st [:,0 ].astype ('int64' ) + imin # shift by minimum sample index
@@ -191,13 +195,19 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
191195
192196 # template properties
193197 similar_templates = CCG .similarity (Wall , ops ['wPCA' ].contiguous (), nt = ops ['nt' ])
194- template_amplitudes = ((Wall ** 2 ).sum (axis = (- 2 ,- 1 ))** 0.5 ).cpu ().numpy ()
198+ temp_amplitudes = ((Wall ** 2 ).sum (axis = (- 2 ,- 1 ))** 0.5 ).cpu ().numpy ()
195199 templates = (Wall .unsqueeze (- 1 ).cpu () * ops ['wPCA' ].cpu ()).sum (axis = - 2 ).numpy ()
196200 templates = templates .transpose (0 ,2 ,1 )
201+ # normalize templates by amplitude
202+ templates = templates / temp_amplitudes [:, np .newaxis , np .newaxis ]
197203 templates_ind = np .tile (np .arange (Wall .shape [1 ])[np .newaxis , :], (templates .shape [0 ],1 ))
198204 np .save ((results_dir / 'similar_templates.npy' ), similar_templates )
199205 np .save ((results_dir / 'templates.npy' ), templates )
200206 np .save ((results_dir / 'templates_ind.npy' ), templates_ind )
207+ # get unwhitened template amplitudes to use as cluster_Amplitudes
208+ iwrot = whitening_mat_inv .to (Wall .device )
209+ unwhitened = torch .einsum ('jk, ikl -> ijl' , iwrot , Wall )
210+ template_amplitudes = ((unwhitened ** 2 ).sum (axis = (- 2 ,- 1 ))** 0.5 ).cpu ().numpy ()
201211
202212 # contamination ratio
203213 acg_threshold = ops ['settings' ]['acg_threshold' ]
0 commit comments