@@ -156,14 +156,18 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
156
156
np .save ((results_dir / 'channel_positions.npy' ), channel_positions )
157
157
158
158
# whitening matrix ** saving real whitening matrix doesn't work with phy currently
159
- whitening_mat = ops ['Wrot' ]. cpu (). numpy ()
160
- np .save ((results_dir / 'whitening_mat_dat.npy' ), whitening_mat )
159
+ whitening_mat = ops ['Wrot' ]
160
+ np .save ((results_dir / 'whitening_mat_dat.npy' ), whitening_mat . cpu () )
161
161
# NOTE: commented out for reference, this was different in KS 2.5 because
162
162
# the binary file was already whitened.
163
163
# whitening_mat = 0.005 * np.eye(len(chan_map), dtype='float32')
164
- whitening_mat_inv = np .linalg .inv (whitening_mat + 1e-5 * np .eye (whitening_mat .shape [0 ]))
165
- np .save ((results_dir / 'whitening_mat.npy' ), whitening_mat )
166
- np .save ((results_dir / 'whitening_mat_inv.npy' ), whitening_mat_inv )
164
+ whitening_mat_inv = torch .inverse (
165
+ whitening_mat
166
+ + 1e-5 * torch .eye (whitening_mat .shape [0 ]).to (whitening_mat .device )
167
+ )
168
+ #whitening_mat_inv = np.linalg.inv(whitening_mat + 1e-5 * np.eye(whitening_mat.shape[0]))
169
+ np .save ((results_dir / 'whitening_mat.npy' ), whitening_mat .cpu ())
170
+ np .save ((results_dir / 'whitening_mat_inv.npy' ), whitening_mat_inv .cpu ())
167
171
168
172
# spike properties
169
173
spike_times = st [:,0 ].astype ('int64' ) + imin # shift by minimum sample index
@@ -191,13 +195,19 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
191
195
192
196
# template properties
193
197
similar_templates = CCG .similarity (Wall , ops ['wPCA' ].contiguous (), nt = ops ['nt' ])
194
- template_amplitudes = ((Wall ** 2 ).sum (axis = (- 2 ,- 1 ))** 0.5 ).cpu ().numpy ()
198
+ temp_amplitudes = ((Wall ** 2 ).sum (axis = (- 2 ,- 1 ))** 0.5 ).cpu ().numpy ()
195
199
templates = (Wall .unsqueeze (- 1 ).cpu () * ops ['wPCA' ].cpu ()).sum (axis = - 2 ).numpy ()
196
200
templates = templates .transpose (0 ,2 ,1 )
201
+ # normalize templates by amplitude
202
+ templates = templates / temp_amplitudes [:, np .newaxis , np .newaxis ]
197
203
templates_ind = np .tile (np .arange (Wall .shape [1 ])[np .newaxis , :], (templates .shape [0 ],1 ))
198
204
np .save ((results_dir / 'similar_templates.npy' ), similar_templates )
199
205
np .save ((results_dir / 'templates.npy' ), templates )
200
206
np .save ((results_dir / 'templates_ind.npy' ), templates_ind )
207
+ # get unwhitened template amplitudes to use as cluster_Amplitudes
208
+ iwrot = whitening_mat_inv .to (Wall .device )
209
+ unwhitened = torch .einsum ('jk, ikl -> ijl' , iwrot , Wall )
210
+ template_amplitudes = ((unwhitened ** 2 ).sum (axis = (- 2 ,- 1 ))** 0.5 ).cpu ().numpy ()
201
211
202
212
# contamination ratio
203
213
acg_threshold = ops ['settings' ]['acg_threshold' ]
0 commit comments