@@ -277,12 +277,38 @@ def test_get_durations(self, time_vector_recording, t_start_recording):
277277 assert np .isclose (tstart_recording .get_total_duration (), sum (all_raw_durations ), rtol = 0 , atol = 1e-8 )
278278 assert np .isclose (tvector_recording .get_total_duration (), sum (all_vector_durations ), rtol = 0 , atol = 1e-8 )
279279
280- # def test_sorting_analyzer_get_durations(self, time_vector_recording):
281- # """ """
282- # breakpoint()
283- # sorting = si.generate_sorting()
284- # sorting_analyzer = si.create_sorting_analyzer(sorting, recording=None)
285- # si.sorting_an
280+ def test_sorting_analyzer_get_durations_from_recording (self , time_vector_recording ):
281+ """
282+ Test that when a recording is set on `sorting_analyzer`, the
283+ total duration is propagated from the recording to the
284+ `sorting_analyzer.get_total_duration()` function.
285+ """
286+ _ , times_recording , _ = time_vector_recording
287+
288+ sorting = si .generate_sorting (
289+ durations = [times_recording .get_duration (s ) for s in range (times_recording .get_num_segments ())]
290+ )
291+ sorting_analyzer = si .create_sorting_analyzer (sorting , recording = times_recording )
292+
293+ assert np .array_equal (sorting_analyzer .get_total_duration (), times_recording .get_total_duration ())
294+
295+ def test_sorting_analyzer_get_durations_no_recording (self , time_vector_recording ):
296+ """
297+ Test when the `sorting_analzyer` does not have a recording set,
298+ the total duration is calculated on the fly from num samples and
299+ sampling frequency (thus matching `raw_recording` with no times set
300+ that uses the same method to calculate the total duration).
301+ """
302+ raw_recording , _ , _ = time_vector_recording
303+
304+ sorting = si .generate_sorting (
305+ durations = [raw_recording .get_duration (s ) for s in range (raw_recording .get_num_segments ())]
306+ )
307+ sorting_analyzer = si .create_sorting_analyzer (sorting , recording = raw_recording )
308+
309+ sorting_analyzer ._recording = None
310+
311+ assert np .array_equal (sorting_analyzer .get_total_duration (), raw_recording .get_total_duration ())
286312
287313 # Helpers ####
288314 def _check_times_match (self , recording , all_times ):
0 commit comments