Skip to content

Commit 26c84ea

Browse files
authored
get_spike_prototype can have NaN sometimes (#2980)
avoid get_spike_prototype can have NaN sometimes with margin in select_peaks
1 parent a3c4c8f commit 26c84ea

File tree

6 files changed

+47
-12
lines changed

6 files changed

+47
-12
lines changed

src/spikeinterface/sorters/internal/simplesorter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
112112

113113
ms_before = params["waveforms"]["ms_before"]
114114
ms_after = params["waveforms"]["ms_after"]
115+
nbefore = int(ms_before * sampling_frequency / 1000.0)
116+
nafter = int(ms_after * sampling_frequency / 1000.0)
115117

116118
# SVD for time compression
117-
few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000)
119+
120+
few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=5000, margin=(nbefore, nafter))
118121
few_wfs = extract_waveform_at_max_channel(
119122
recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
120123
)

src/spikeinterface/sortingcomponents/clustering/circus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def main_function(cls, recording, peaks, params):
9090
tmp_folder.mkdir(parents=True, exist_ok=True)
9191

9292
# SVD for time compression
93-
few_peaks = select_peaks(peaks, method="uniform", n_peaks=10000)
93+
few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=10000, margin=(nbefore, nafter))
9494
few_wfs = extract_waveform_at_max_channel(
9595
recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **params["job_kwargs"]
9696
)

src/spikeinterface/sortingcomponents/clustering/tdc.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,11 @@ def main_function(cls, recording, peaks, params):
7373
ms_before = params["waveforms"]["ms_before"]
7474
ms_after = params["waveforms"]["ms_after"]
7575

76+
nbefore = int(ms_before * sampling_frequency / 1000.0)
77+
nafter = int(ms_after * sampling_frequency / 1000.0)
78+
7679
# SVD for time compression
77-
few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000)
80+
few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=5000, margin=(nbefore, nafter))
7881
few_wfs = extract_waveform_at_max_channel(
7982
recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
8083
)

src/spikeinterface/sortingcomponents/peak_selection.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import numpy as np
77

88

9-
def select_peaks(peaks, method="uniform", seed=None, return_indices=False, **method_kwargs):
9+
def select_peaks(
10+
peaks, recording=None, method="uniform", seed=None, return_indices=False, margin=None, **method_kwargs
11+
):
1012
"""
1113
Method to select a subset of peaks from a set of peaks.
1214
Usually use for reducing computational foorptint of downstream methods.
@@ -28,6 +30,9 @@ def select_peaks(peaks, method="uniform", seed=None, return_indices=False, **met
2830
The seed for random generations
2931
return_indices: bool
3032
If True, return the indices of selection such that selected_peaks = peaks[selected_indices]
33+
margin : Margin in timesteps. default: None. Otherwise should be a tuple (nbefore, nafter)
34+
preventing peaks to be selected at the borders of the segments. A recording should be provided to get the duration
35+
of the segments
3136
3237
method_kwargs: dict of kwargs method
3338
Keyword arguments for the chosen method:
@@ -66,8 +71,27 @@ def select_peaks(peaks, method="uniform", seed=None, return_indices=False, **met
6671
return_indices is True.
6772
"""
6873

74+
if margin is not None:
75+
assert recording is not None, "recording should be provided if margin is not None"
76+
6977
selected_indices = select_peak_indices(peaks, method=method, seed=seed, **method_kwargs)
7078
selected_peaks = peaks[selected_indices]
79+
80+
if margin is not None:
81+
to_keep = np.zeros(len(selected_peaks), dtype=bool)
82+
offset = 0
83+
for segment_index in range(recording.get_num_segments()):
84+
duration = recording.get_num_frames(segment_index)
85+
i0, i1 = np.searchsorted(selected_peaks["segment_index"], [segment_index, segment_index + 1])
86+
while selected_peaks["sample_index"][i0] <= margin[0] + offset:
87+
i0 += 1
88+
while selected_peaks["sample_index"][i1 - 1] >= (duration - margin[1]) + offset:
89+
i1 -= 1
90+
to_keep[i0:i1] = True
91+
offset += duration
92+
selected_indices = selected_indices[to_keep]
93+
selected_peaks = peaks[selected_indices]
94+
7195
if return_indices:
7296
return selected_peaks, selected_indices
7397
else:

src/spikeinterface/sortingcomponents/tests/test_peak_selection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def test_select_peaks():
4545
selected_peaks.size <= n_peaks
4646
), "selected_peaks is not the right size when return_indices=False, select_per_channel=False"
4747

48+
selected_peaks = select_peaks(peaks, recording=recording, method=method, margin=(10, 10), **select_kwargs)
49+
assert (
50+
selected_peaks.size <= n_peaks
51+
), "selected_peaks is not the right size when return_indices=False, select_per_channel=False"
52+
4853
selected_peaks = select_peaks(peaks, method=method, select_per_channel=True, **select_kwargs)
4954
assert selected_peaks.size <= (
5055
n_peaks * recording.get_num_channels()

src/spikeinterface/sortingcomponents/tools.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,18 @@ def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j
7070

7171

7272
def get_prototype_spike(recording, peaks, ms_before=0.5, ms_after=0.5, nb_peaks=1000, **job_kwargs):
73-
if peaks.size > nb_peaks:
74-
idx = np.sort(np.random.choice(len(peaks), nb_peaks, replace=False))
75-
some_peaks = peaks[idx]
76-
else:
77-
some_peaks = peaks
78-
7973
nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
74+
nafter = int(ms_after * recording.sampling_frequency / 1000.0)
75+
76+
from spikeinterface.sortingcomponents.peak_selection import select_peaks
77+
78+
few_peaks = select_peaks(peaks, recording=recording, method="uniform", n_peaks=nb_peaks, margin=(nbefore, nafter))
8079

8180
waveforms = extract_waveform_at_max_channel(
82-
recording, some_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
81+
recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
8382
)
84-
prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0)
83+
with np.errstate(divide="ignore", invalid="ignore"):
84+
prototype = np.median(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0)
8585
return prototype
8686

8787

0 commit comments

Comments
 (0)