from scipy import stats

# TODO: spike_times -> spike_indexes
11+ """
12+ Notes
13+ -----
14+ - not everything is used for current purposes
15+ - things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude.
16+ """


def compute_spike_amplitude_and_depth(
@@ -75,53 +81,58 @@ def compute_spike_amplitude_and_depth(

    localised_template_by_spike = np.isin(params["spike_templates"], localised_templates)

78- params ["spike_templates" ] = params ["spike_templates" ][localised_template_by_spike ]
79- params ["spike_indexes" ] = params ["spike_indexes" ][localised_template_by_spike ]
80- params ["spike_clusters" ] = params ["spike_clusters" ][localised_template_by_spike ]
81- params ["temp_scaling_amplitudes" ] = params ["temp_scaling_amplitudes" ][localised_template_by_spike ]
82- params ["pc_features" ] = params ["pc_features" ][localised_template_by_spike ]
84+ _strip_spikes (params , localised_template_by_spike )

    # Compute spike depths
-    pc_features = params["pc_features"][:, 0, :]
+    pc_features = params["pc_features"][:, 0, :]  # PC1 loading of each spike on its PC feature channels
    pc_features[pc_features < 0] = 0

-    # Get the channel indexes corresponding to the 32 channels from the PC.
-    spike_features_indices = params["pc_features_indices"][params["spike_templates"], :]
+    # Some spikes do not load at all onto the first PC. To avoid biasing the
+    # dataset by removing these, we repeat the above for the next PC, so that
+    # positions can still be computed for spikes that do not load onto the 1st PC.
+    # This is not ideal. It would be much better to a) find the max value for each
+    # channel on each of the PCs (i.e. basis vectors), then b) recompute the
+    # estimated waveform peak on each channel by summing the PCs weighted by
+    # their respective scores. However, the PC basis vectors themselves do not
+    # appear to be output by KS.
+    no_pc1_signal_spikes = np.where(np.sum(pc_features, axis=1) == 0)
+
+    pc_features_2 = params["pc_features"][:, 1, :]
+    pc_features_2[pc_features_2 < 0] = 0
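+    # As for PC1 above, negative loadings are zeroed so that only positive
+    # loadings contribute to the center-of-mass weights computed below.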

-    ycoords = params["channel_positions"][:, 1]
-    spike_feature_ycoords = ycoords[spike_features_indices]
+    pc_features[no_pc1_signal_spikes] = pc_features_2[no_pc1_signal_spikes]

-    spike_depths = np.sum(spike_feature_ycoords * pc_features**2, axis=1) / np.sum(pc_features**2, axis=1)
+    if np.any(np.sum(pc_features, axis=1) == 0):
+        raise RuntimeError(
+            "Some spikes do not load at all onto the first "
+            "or second principal component. It is necessary "
+            "to extend this code section to handle more components."
+        )

+    # Get the channel indexes corresponding to the 32 channels from the PC.
+    spike_features_indices = params["pc_features_indices"][params["spike_templates"], :]
+
+    # Compute the spike locations as the center of mass of the PC scores
    spike_feature_coords = params["channel_positions"][spike_features_indices, :]
    norm_weights = pc_features / np.sum(pc_features, axis=1)[:, np.newaxis]  # TODO: see why they use square
-    weighted_locs = spike_feature_coords * norm_weights[:, :, np.newaxis]
-    weighted_locs = np.sum(weighted_locs, axis=1)
+    spike_locations = spike_feature_coords * norm_weights[:, :, np.newaxis]
+    spike_locations = np.sum(spike_locations, axis=1)
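+    # norm_weights is (num_spikes, 32) and spike_feature_coords is (num_spikes, 32, 2),
+    # so spike_locations is a (num_spikes, 2) array of x, y positions per spike.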
+
+    # TODO: the max site per spike is now computed from the PCs, rather than as the
+    # channel with the template maximum as before.
+    spike_sites = spike_features_indices[np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1)]
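+    # i.e. for each spike, the site is the PC feature channel that carries the
+    # largest normalised weight for that spike.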
+
    # Amplitude is calculated for each spike as the template amplitude
    # multiplied by the `template_scaling_amplitudes`.
-
-    # Compute amplitudes, scale if required and drop un-localised spikes before returning.
-    spike_amplitudes, _, _, _, unwhite_templates, *_ = get_template_info_and_spike_amplitudes(
+    template_amplitudes_unscaled, *_ = get_unwhite_template_info(
        params["templates"],
        params["whitening_matrix_inv"],
-        ycoords,
-        params["spike_templates"],
-        params["temp_scaling_amplitudes"],
+        params["channel_positions"],
    )
+    spike_amplitudes = template_amplitudes_unscaled[params["spike_templates"]] * params["temp_scaling_amplitudes"]
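+    # template_amplitudes_unscaled is (num_templates,), so indexing it by the spike
+    # templates and scaling gives a (num_spikes,) amplitude per spike.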

    if gain is not None:
        spike_amplitudes *= gain

-    max_site = np.argmax(
-        np.max(np.abs(templates), axis=1), axis=1
-    )  # TODO: combine this with above function. Maybe the above function can be templates only, and everything spike-related is here.
-    max_site = np.argmax(np.max(np.abs(unwhite_templates), axis=1), axis=1)
-    spike_sites = max_site[params["spike_templates"]]
-
-    # TODO: here the max site is the same for all spikes from the same template.
-    # is this the case for spikeinterface? Should we estimate max-site per spike from
-    # the PCs?
-
    if localised_spikes_only:
        # Interpolate the channel ids to location.
        # Remove spikes > 5 um from average position
@@ -130,23 +141,32 @@ def compute_spike_amplitude_and_depth(
        # TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
        # 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
        # 3) just use depth. Probably go for that. check with others.
-        spike_depths = weighted_locs[:, 1]
+        spike_depths = spike_locations[:, 1]
        b = stats.linregress(spike_depths, spike_sites).slope
        i = np.abs(spike_sites - b * spike_depths) <= 5  # TODO: need to expose this

        params["spike_indexes"] = params["spike_indexes"][i]
        spike_amplitudes = spike_amplitudes[i]
-        weighted_locs = weighted_locs[i, :]
+        spike_locations = spike_locations[i, :]
+
+    return params["spike_indexes"], spike_amplitudes, spike_locations, spike_sites

-    return params["spike_indexes"], spike_amplitudes, weighted_locs, spike_sites  # TODO: rename everything

+def _strip_spikes_in_place(params, indices):
+    """Filter all per-spike arrays in `params` in place, keeping only the entries selected by `indices`."""
+    params["spike_templates"] = params["spike_templates"][indices]
+    params["spike_indexes"] = params["spike_indexes"][indices]
+    params["spike_clusters"] = params["spike_clusters"][indices]
+    params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][indices]
+    params["pc_features"] = params["pc_features"][indices]  # TODO: be consistent! Change "indexes" to "indices".

-def get_template_info_and_spike_amplitudes(
+
+def get_unwhite_template_info(
    templates: np.ndarray,
    inverse_whitening_matrix: np.ndarray,
-    ycoords: np.ndarray,
-    spike_templates: np.ndarray,
-    template_scaling_amplitudes: np.ndarray,
+    channel_positions: np.ndarray,
) -> tuple[np.ndarray, ...]:
    """
    Calculate the amplitude and depths of (unwhitened) templates and spikes.
@@ -163,28 +183,20 @@ def get_template_info_and_spike_amplitudes(
    inverse_whitening_matrix: np.ndarray
        Inverse of the whitening matrix used in KS preprocessing, used to
        unwhiten templates.
-    ycoords : np.ndarray
-        (num_channels,) array of the y-axis (depth) channel positions.
-    spike_templates : np.ndarray
-        (num_spikes,) array indicating the template associated with each spike.
-    template_scaling_amplitudes : np.ndarray
-        (num_spikes,) array holding the scaling amplitudes, by which the
-        template was scaled to match each spike.
+    channel_positions : np.ndarray
+        (num_channels, 2) array of the x, y channel positions.

    Returns
    -------
-    spike_amplitudes : np.ndarray
-        (num_spikes,) array of the amplitude of each spike.
-    spike_depths : np.ndarray
-        (num_spikes,) array of the depth (probe y-axis) of each spike. Note
-        this is just the template depth for each spike (i.e. depth of all spikes
-        from the same cluster are identical).
-    template_amplitudes : np.ndarray
-        (num_templates,) Amplitude of each template, calculated as average of spike amplitudes.
-    template_depths : np.ndarray
-        (num_templates,) array of the depth of each template.
+    template_amplitudes_unscaled : np.ndarray
+        (num_templates,) array of the unscaled template amplitudes. These can be
+        used to calculate spike amplitudes together with the `temp_scaling_amplitudes`.
+    template_locations : np.ndarray
+        (num_templates, 2) array of the x, y positions (center of mass) of each template.
    unwhite_templates : np.ndarray
        Unwhitened templates (num_clusters, num_samples, num_channels).
+    template_max_site : np.ndarray
+        (num_templates,) array giving the channel index with the largest amplitude
+        for each unwhitened template.
    trough_peak_durations : np.ndarray
        (num_templates,) array of durations from trough to peak for each template waveform
    waveforms : np.ndarray
@@ -195,43 +207,31 @@ def get_template_info_and_spike_amplitudes(
    for idx, template in enumerate(templates):
        unwhite_templates[idx, :, :] = templates[idx, :, :] @ inverse_whitening_matrix

-    # First, calculate the depth of each template from the amplitude
-    # on each channel by the center of mass method.
-
    # Take the max amplitude for each channel, then use the channel
-    # with most signal as template amplitude. Zero any small channel amplitudes.
+    # with most signal as template amplitude.
    template_amplitudes_per_channel = np.max(unwhite_templates, axis=1) - np.min(unwhite_templates, axis=1)

    template_amplitudes_unscaled = np.max(template_amplitudes_per_channel, axis=1)

-    threshold_values = 0.3 * template_amplitudes_unscaled
-    template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0
+    # Zero any small channel amplitudes. TODO: removed (commented out) to be more general. Agree?
+    # threshold_values = 0.3 * template_amplitudes_unscaled
+    # template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0

    # Calculate the template depth as the center of mass based on channel amplitudes
-    template_depths = np.sum(template_amplitudes_per_channel * ycoords[np.newaxis, :], axis=1) / np.sum(
-        template_amplitudes_per_channel, axis=1
-    )
-
-    # Next, find the depth of each spike based on its template. Recompute the template
-    # amplitudes as the average of the spike amplitudes ('since
-    # tempScalingAmps are equal mean for all templates')
-    spike_amplitudes = template_amplitudes_unscaled[spike_templates] * template_scaling_amplitudes
-
-    # Take the average of all spike amplitudes to get actual template amplitudes
-    # (since tempScalingAmps are equal mean for all templates)
-    num_indices = templates.shape[0]
-    sum_per_index = np.zeros(num_indices, dtype=np.float64)
-    np.add.at(sum_per_index, spike_templates, spike_amplitudes)
-    counts = np.bincount(spike_templates, minlength=num_indices)
-    template_amplitudes = np.divide(sum_per_index, counts, out=np.zeros_like(sum_per_index), where=counts != 0)
+    weights = template_amplitudes_per_channel / np.sum(template_amplitudes_per_channel, axis=1)[:, np.newaxis]
+    template_locations = weights @ channel_positions
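+    # (num_templates, num_channels) weights @ (num_channels, 2) positions gives a
+    # (num_templates, 2) array of x, y center-of-mass positions per template.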

    # Get channel with the largest amplitude (take that as the waveform)
-    max_site = np.argmax(np.max(np.abs(templates), axis=1), axis=1)
+    template_max_site = np.argmax(
+        np.max(np.abs(unwhite_templates), axis=1), axis=1
+    )  # TODO: changed to use the unwhitened templates instead of the whitened templates. Is this okay?

    # Use template channel with max signal as waveform
-    waveforms = np.empty(templates.shape[:2])
-    for idx, template in enumerate(templates):
-        waveforms[idx, :] = templates[idx, :, max_site[idx]]
+    waveforms = np.empty(
+        unwhite_templates.shape[:2]
+    )  # TODO: changed to use the unwhitened templates instead of the whitened templates. Is this okay?
+    for idx, template in enumerate(unwhite_templates):
+        waveforms[idx, :] = unwhite_templates[idx, :, template_max_site[idx]]

    # Get trough-to-peak time for each template. Find the trough as the
    # minimum signal for the template waveform. The duration (in
@@ -244,15 +244,26 @@ def get_template_info_and_spike_amplitudes(
        trough_peak_durations[idx] = np.argmax(tmp_max[waveform_trough[idx] :])

    return (
-        spike_amplitudes,
-        template_depths,
-        template_amplitudes,
+        template_amplitudes_unscaled,
+        template_locations,
+        template_max_site,
        unwhite_templates,
        trough_peak_durations,
        waveforms,
    )


+def compute_template_amplitudes_from_spikes(templates, spike_templates, spike_amplitudes):
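+    """
+    Average the amplitudes of a template's spikes to estimate the template
+    amplitude (since tempScalingAmps have an equal mean across templates).
+
+    Assumed shapes, matching the arrays used elsewhere in this module:
+    `templates` is (num_templates, num_samples, num_channels), `spike_templates`
+    is a (num_spikes,) array of template ids and `spike_amplitudes` is
+    (num_spikes,). Returns a (num_templates,) array; templates with no spikes
+    get an amplitude of 0.
+    """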
+    # Take the average of all spike amplitudes to get actual template amplitudes
+    # (since tempScalingAmps are equal mean for all templates)
+    num_indices = templates.shape[0]
+    sum_per_index = np.zeros(num_indices, dtype=np.float64)
+    np.add.at(sum_per_index, spike_templates, spike_amplitudes)
+    counts = np.bincount(spike_templates, minlength=num_indices)
+    template_amplitudes = np.divide(sum_per_index, counts, out=np.zeros_like(sum_per_index), where=counts != 0)
+    return template_amplitudes
+
+
def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict:
    """
    Loads the output of Kilosort into a `params` dict.