Skip to content

Commit 6b8c540

Browse files
committed
Add test for sorting analzyer total duration.
1 parent 31b94df commit 6b8c540

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

src/spikeinterface/core/tests/test_time_handling.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -277,12 +277,38 @@ def test_get_durations(self, time_vector_recording, t_start_recording):
277277
assert np.isclose(tstart_recording.get_total_duration(), sum(all_raw_durations), rtol=0, atol=1e-8)
278278
assert np.isclose(tvector_recording.get_total_duration(), sum(all_vector_durations), rtol=0, atol=1e-8)
279279

280-
# def test_sorting_analyzer_get_durations(self, time_vector_recording):
281-
# """ """
282-
# breakpoint()
283-
# sorting = si.generate_sorting()
284-
# sorting_analyzer = si.create_sorting_analyzer(sorting, recording=None)
285-
# si.sorting_an
280+
def test_sorting_analyzer_get_durations_from_recording(self, time_vector_recording):
281+
"""
282+
Test that when a recording is set on `sorting_analyzer`, the
283+
total duration is propagated from the recording to the
284+
`sorting_analyzer.get_total_duration()` function.
285+
"""
286+
_, times_recording, _ = time_vector_recording
287+
288+
sorting = si.generate_sorting(
289+
durations=[times_recording.get_duration(s) for s in range(times_recording.get_num_segments())]
290+
)
291+
sorting_analyzer = si.create_sorting_analyzer(sorting, recording=times_recording)
292+
293+
assert np.array_equal(sorting_analyzer.get_total_duration(), times_recording.get_total_duration())
294+
295+
def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording):
296+
"""
297+
Test when the `sorting_analzyer` does not have a recording set,
298+
the total duration is calculated on the fly from num samples and
299+
sampling frequency (thus matching `raw_recording` with no times set
300+
that uses the same method to calculate the total duration).
301+
"""
302+
raw_recording, _, _ = time_vector_recording
303+
304+
sorting = si.generate_sorting(
305+
durations=[raw_recording.get_duration(s) for s in range(raw_recording.get_num_segments())]
306+
)
307+
sorting_analyzer = si.create_sorting_analyzer(sorting, recording=raw_recording)
308+
309+
sorting_analyzer._recording = None
310+
311+
assert np.array_equal(sorting_analyzer.get_total_duration(), raw_recording.get_total_duration())
286312

287313
# Helpers ####
288314
def _check_times_match(self, recording, all_times):

0 commit comments

Comments
 (0)