@@ -243,6 +243,73 @@ def test_slice_recording(self, time_type, bounds):
243243
244244 assert np .allclose (rec_slice .get_times (0 ), all_times [0 ][start_frame :end_frame ], rtol = 0 , atol = 1e-8 )
245245
246+ def test_get_durations (self , time_vector_recording , t_start_recording ):
247+ """
248+ Test the `get_durations` functions that return the total duration
249+ for a segment. Test that it is correct after adding both `t_start`
250+ or `time_vector` to the recording.
251+ """
252+ raw_recording , tvector_recording , all_time_vectors = time_vector_recording
253+ _ , tstart_recording , all_t_starts = t_start_recording
254+
255+ ts = 1 / raw_recording .get_sampling_frequency ()
256+
257+ all_raw_durations = []
258+ all_vector_durations = []
259+ for segment_index in range (raw_recording .get_num_segments ()):
260+
261+ # Test before `t_start` and `t_start` (`t_start` is just an offset,
262+ # should not affect duration).
263+ raw_duration = all_t_starts [segment_index ][- 1 ] - all_t_starts [segment_index ][0 ] + ts
264+
265+ assert np .isclose (raw_recording .get_duration (segment_index ), raw_duration , rtol = 0 , atol = 1e-8 )
266+ assert np .isclose (tstart_recording .get_duration (segment_index ), raw_duration , rtol = 0 , atol = 1e-8 )
267+
268+ # Test the duration from the time vector.
269+ vector_duration = all_time_vectors [segment_index ][- 1 ] - all_time_vectors [segment_index ][0 ] + ts
270+
271+ assert tvector_recording .get_duration (segment_index ) == vector_duration
272+
273+ all_raw_durations .append (raw_duration )
274+ all_vector_durations .append (vector_duration )
275+
276+ # Finally test the total recording duration
277+ assert np .isclose (tstart_recording .get_total_duration (), sum (all_raw_durations ), rtol = 0 , atol = 1e-8 )
278+ assert np .isclose (tvector_recording .get_total_duration (), sum (all_vector_durations ), rtol = 0 , atol = 1e-8 )
279+
280+ def test_sorting_analyzer_get_durations_from_recording (self , time_vector_recording ):
281+ """
282+ Test that when a recording is set on `sorting_analyzer`, the
283+ total duration is propagated from the recording to the
284+ `sorting_analyzer.get_total_duration()` function.
285+ """
286+ _ , times_recording , _ = time_vector_recording
287+
288+ sorting = si .generate_sorting (
289+ durations = [times_recording .get_duration (s ) for s in range (times_recording .get_num_segments ())]
290+ )
291+ sorting_analyzer = si .create_sorting_analyzer (sorting , recording = times_recording )
292+
293+ assert np .array_equal (sorting_analyzer .get_total_duration (), times_recording .get_total_duration ())
294+
295+ def test_sorting_analyzer_get_durations_no_recording (self , time_vector_recording ):
296+ """
297+ Test when the `sorting_analzyer` does not have a recording set,
298+ the total duration is calculated on the fly from num samples and
299+ sampling frequency (thus matching `raw_recording` with no times set
300+ that uses the same method to calculate the total duration).
301+ """
302+ raw_recording , _ , _ = time_vector_recording
303+
304+ sorting = si .generate_sorting (
305+ durations = [raw_recording .get_duration (s ) for s in range (raw_recording .get_num_segments ())]
306+ )
307+ sorting_analyzer = si .create_sorting_analyzer (sorting , recording = raw_recording )
308+
309+ sorting_analyzer ._recording = None
310+
311+ assert np .array_equal (sorting_analyzer .get_total_duration (), raw_recording .get_total_duration ())
312+
246313 # Helpers ####
247314 def _check_times_match (self , recording , all_times ):
248315 """
0 commit comments