@@ -148,16 +148,15 @@ def lfp_timestamps(self):
148148 self ._lfp_timestamps = np .hstack ([s .times for s in self .lfp_analog_signals ])
149149 return self ._lfp_timestamps
150150
151- def extract_spike_waveforms (self , spikes , channel , n_wf = 500 , wf_win = (- 32 , 32 )):
151+ def extract_spike_waveforms (self , spikes , channel_ind , n_wf = 500 , wf_win = (- 32 , 32 )):
152152 """
153153 :param spikes: spike times (in second) to extract waveforms
154- :param channel : channel (name, not indices ) to extract waveforms
154+ :param channel_ind : channel indices (of meta['channels_ids'] ) to extract waveforms
155155 :param n_wf: number of spikes per unit to extract the waveforms
156156 :param wf_win: number of sample pre and post a spike
157157 :return: waveforms (sample x channel x spike)
158158 """
159- channel_ind = [np .where (self .ap_meta ['channels_ids' ] == chn )[0 ][0 ] for chn in channel ]
160- channel_bit_volts = self .ap_meta ['channels_gains' ][channel_ind ]
159+ channel_bit_volts = np .array (self .ap_meta ['channels_gains' ])[channel_ind ]
161160
162161 # ignore spikes at the beginning or end of raw data
163162 spikes = spikes [np .logical_and (spikes > (- wf_win [0 ] / self .ap_meta ['sample_rate' ]),
@@ -174,4 +173,4 @@ def extract_spike_waveforms(self, spikes, channel, n_wf=500, wf_win=(-32, 32)):
174173 for spk in spike_indices ])
175174 return spike_wfs
176175 else : # if no spike found, return NaN of size (sample x channel x 1)
177- return np .full ((len (range (* wf_win )), len (channel ), 1 ), np .nan )
176+ return np .full ((len (range (* wf_win )), len (channel_ind ), 1 ), np .nan )
0 commit comments