@@ -76,18 +76,19 @@ def select_peaks(
7676
7777 selected_indices = select_peak_indices (peaks , method = method , seed = seed , ** method_kwargs )
7878 selected_peaks = peaks [selected_indices ]
79- num_segments = len (np .unique (selected_peaks ["segment_index" ]))
8079
8180 if margin is not None :
8281 to_keep = np .zeros (len (selected_peaks ), dtype = bool )
83- for segment_index in range (num_segments ):
84- num_samples_in_segment = recording .get_num_samples (segment_index )
82+ offset = 0
83+ for segment_index in range (recording .get_num_segments ()):
84+ duration = recording .get_num_frames (segment_index )
8585 i0 , i1 = np .searchsorted (selected_peaks ["segment_index" ], [segment_index , segment_index + 1 ])
86- while selected_peaks ["sample_index" ][i0 ] <= margin [0 ]:
86+ while selected_peaks ["sample_index" ][i0 ] <= margin [0 ] + offset :
8787 i0 += 1
88- while selected_peaks ["sample_index" ][i1 - 1 ] >= (num_samples_in_segment - margin [1 ]):
88+ while selected_peaks ["sample_index" ][i1 - 1 ] >= (duration - margin [1 ]) + offset :
8989 i1 -= 1
9090 to_keep [i0 :i1 ] = True
91+ offset += duration
9192 selected_indices = selected_indices [to_keep ]
9293 selected_peaks = peaks [selected_indices ]
9394
@@ -283,9 +284,7 @@ def select_peak_indices(peaks, method, seed, **method_kwargs):
283284 )
284285
285286 selected_indices = np .concatenate (selected_indices )
286- selected_indices = selected_indices [
287- np .lexsort ((peaks [selected_indices ]["sample_index" ], peaks [selected_indices ]["segment_index" ]))
288- ]
287+ selected_indices = selected_indices [np .argsort (peaks [selected_indices ]["sample_index" ])]
289288 return selected_indices
290289
291290
0 commit comments