@@ -76,19 +76,18 @@ def select_peaks(
7676
7777 selected_indices = select_peak_indices (peaks , method = method , seed = seed , ** method_kwargs )
7878 selected_peaks = peaks [selected_indices ]
79+ num_segments = len (np .unique (selected_peaks ["segment_index" ]))
7980
8081 if margin is not None :
8182 to_keep = np .zeros (len (selected_peaks ), dtype = bool )
82- offset = 0
83- for segment_index in range (recording .get_num_segments ()):
84- duration = recording .get_num_frames (segment_index )
83+ for segment_index in range (num_segments ):
84+ num_samples_in_segment = recording .get_num_samples (segment_index )
8585 i0 , i1 = np .searchsorted (selected_peaks ["segment_index" ], [segment_index , segment_index + 1 ])
86- while selected_peaks ["sample_index" ][i0 ] <= margin [0 ] + offset :
86+ while selected_peaks ["sample_index" ][i0 ] <= margin [0 ]:
8787 i0 += 1
88- while selected_peaks ["sample_index" ][i1 - 1 ] >= (duration - margin [1 ]) + offset :
88+ while selected_peaks ["sample_index" ][i1 - 1 ] >= (num_samples_in_segment - margin [1 ]):
8989 i1 -= 1
9090 to_keep [i0 :i1 ] = True
91- offset += duration
9291 selected_indices = selected_indices [to_keep ]
9392 selected_peaks = peaks [selected_indices ]
9493
@@ -284,7 +283,9 @@ def select_peak_indices(peaks, method, seed, **method_kwargs):
284283 )
285284
286285 selected_indices = np .concatenate (selected_indices )
287- selected_indices = selected_indices [np .argsort (peaks [selected_indices ]["sample_index" ])]
286+ selected_indices = selected_indices [
287+ np .lexsort ((peaks [selected_indices ]["sample_index" ], peaks [selected_indices ]["segment_index" ]))
288+ ]
288289 return selected_indices
289290
290291
0 commit comments