@@ -49,7 +49,7 @@ def compute_spike_amplitude_and_depth(
4949
5050 Notes
5151 -----
52- In `_template_positions_amplitudes ` spike depths is calculated as simply the template
52+ In `get_template_info_and_spike_amplitudes ` spike depths is calculated as simply the template
5353 depth, for each spike (so it is the same for all spikes in a cluster). Here we need
5454 to find the depth of each individual spike, using its low-dimensional projection.
5555 `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike.
@@ -101,7 +101,7 @@ def compute_spike_amplitude_and_depth(
101101 # multiplied by the `template_scaling_amplitudes`.
102102
103103 # Compute amplitudes, scale if required and drop un-localised spikes before returning.
104- spike_amplitudes , _ , _ , _ , unwhite_templates , * _ = _template_positions_amplitudes (
104+ spike_amplitudes , _ , _ , _ , unwhite_templates , * _ = get_template_info_and_spike_amplitudes (
105105 params ["templates" ],
106106 params ["whitening_matrix_inv" ],
107107 ycoords ,
@@ -112,9 +112,16 @@ def compute_spike_amplitude_and_depth(
112112 if gain is not None :
113113 spike_amplitudes *= gain
114114
115+ max_site = np .argmax (
116+ np .max (np .abs (templates ), axis = 1 ), axis = 1
117+ ) # TODO: combine this with above function. Maybe the above function can be templates only, and everything spike-related is here.
115118 max_site = np .argmax (np .max (np .abs (unwhite_templates ), axis = 1 ), axis = 1 )
116119 spike_sites = max_site [params ["spike_templates" ]]
117120
121+ # TODO: here the max site is the same for all spikes from the same template.
122+ # is this the case for spikeinterface? Should we estimate max-site per spike from
123+ # the PCs?
124+
118125 if localised_spikes_only :
119126 # Interpolate the channel ids to location.
120127 # Remove spikes > 5 um from average position
@@ -134,45 +141,7 @@ def compute_spike_amplitude_and_depth(
134141 return params ["spike_indexes" ], spike_amplitudes , weighted_locs , spike_sites # TODO: rename everything
135142
136143
137- def _filter_large_amplitude_spikes (
138- spike_times : np .ndarray ,
139- spike_amplitudes : np .ndarray ,
140- spike_depths : np .ndarray ,
141- large_amplitude_only_segment_size ,
142- ) -> tuple [np .ndarray , ...]:
143- """
144- Return spike properties with only the largest-amplitude spikes included. The probe
145- is split into egments, and within each segment the mean and std computed.
146- Any spike less than 1.5x the standard deviation in amplitude of it's segment is excluded
147- Splitting the probe is only done for the exclusion step, the returned array are flat.
148-
149- Takes as input arrays `spike_times`, `spike_depths` and `spike_amplitudes` and returns
150- copies of these arrays containing only the large amplitude spikes.
151- """
152- spike_bool = np .zeros_like (spike_amplitudes , dtype = bool )
153-
154- segment_size_um = large_amplitude_only_segment_size
155- probe_segments_left_edges = np .arange (np .floor (spike_depths .max () / segment_size_um ) + 1 ) * segment_size_um
156-
157- for segment_left_edge in probe_segments_left_edges :
158- segment_right_edge = segment_left_edge + segment_size_um
159-
160- spikes_in_seg = np .where (np .logical_and (spike_depths >= segment_left_edge , spike_depths < segment_right_edge ))[
161- 0
162- ]
163- spike_amps_in_seg = spike_amplitudes [spikes_in_seg ]
164- is_high_amplitude = spike_amps_in_seg > np .mean (spike_amps_in_seg ) + 1.5 * np .std (spike_amps_in_seg , ddof = 1 )
165-
166- spike_bool [spikes_in_seg ] = is_high_amplitude
167-
168- spike_times = spike_times [spike_bool ]
169- spike_amplitudes = spike_amplitudes [spike_bool ]
170- spike_depths = spike_depths [spike_bool ]
171-
172- return spike_times , spike_amplitudes , spike_depths
173-
174-
175- def _template_positions_amplitudes (
144+ def get_template_info_and_spike_amplitudes (
176145 templates : np .ndarray ,
177146 inverse_whitening_matrix : np .ndarray ,
178147 ycoords : np .ndarray ,
@@ -256,9 +225,6 @@ def _template_positions_amplitudes(
256225 counts = np .bincount (spike_templates , minlength = num_indices )
257226 template_amplitudes = np .divide (sum_per_index , counts , out = np .zeros_like (sum_per_index ), where = counts != 0 )
258227
259- # Each spike's depth is the depth of its template
260- spike_depths = template_depths [spike_templates ]
261-
262228 # Get channel with the largest amplitude (take that as the waveform)
263229 max_site = np .argmax (np .max (np .abs (templates ), axis = 1 ), axis = 1 )
264230
@@ -279,7 +245,6 @@ def _template_positions_amplitudes(
279245
280246 return (
281247 spike_amplitudes ,
282- spike_depths ,
283248 template_depths ,
284249 template_amplitudes ,
285250 unwhite_templates ,
0 commit comments