@@ -15,7 +15,10 @@ class TestTimeHandling:
1515 is generated on the fly. Both time representations are tested here.
1616 """
1717
18- # Fixtures #####
18+ # #########################################################################
19+ # Fixtures
20+ # #########################################################################
21+
1922 @pytest .fixture (scope = "session" )
2023 def time_vector_recording (self ):
2124 """
@@ -95,7 +98,10 @@ def _get_fixture_data(self, request, fixture_name):
9598 raw_recording , times_recording , all_times = time_recording_fixture
9699 return (raw_recording , times_recording , all_times )
97100
98- # Tests #####
101+ # #########################################################################
102+ # Tests
103+ # #########################################################################
104+
99105 def test_has_time_vector (self , time_vector_recording ):
100106 """
101107 Test the `has_time_vector` function returns `False` before
@@ -305,7 +311,87 @@ def test_sorting_analyzer_get_durations_no_recording(self, time_vector_recording
305311
306312 assert np .array_equal (sorting_analyzer .get_total_duration (), raw_recording .get_total_duration ())
307313
308- # Helpers ####
314+ @pytest .mark .parametrize ("fixture_name" , ["time_vector_recording" , "t_start_recording" ])
315+ @pytest .mark .parametrize ("shift" , [- 123.456 , 123.456 ])
316+ def test_shift_time_all_segments (self , request , fixture_name , shift ):
317+ """
318+ Shift the times in every segment using the `None` default, then
319+ check that every segment of the recording is shifted as expected.
320+ """
321+ _ , times_recording , all_times = self ._get_fixture_data (request , fixture_name )
322+
323+ num_segments , orig_seg_data = self ._store_all_times (times_recording )
324+
325+ times_recording .shift_times (shift ) # use default `segment_index=None`
326+
327+ for idx in range (num_segments ):
328+ assert np .allclose (
329+ orig_seg_data [idx ], times_recording .get_times (segment_index = idx ) - shift , rtol = 0 , atol = 1e-8
330+ )
331+
332+ @pytest .mark .parametrize ("fixture_name" , ["time_vector_recording" , "t_start_recording" ])
333+ @pytest .mark .parametrize ("shift" , [- 123.456 , 123.456 ])
334+ def test_shift_times_different_segments (self , request , fixture_name , shift ):
335+ """
336+ Shift each segment separately, and check the shifted segment only
337+ is shifted as expected.
338+ """
339+ _ , times_recording , all_times = self ._get_fixture_data (request , fixture_name )
340+
341+ num_segments , orig_seg_data = self ._store_all_times (times_recording )
342+
343+ # For each segment, shift the segment only and check the
344+ # times are updated as expected.
345+ for idx in range (num_segments ):
346+
347+ scaler = idx + 2
348+ times_recording .shift_times (shift * scaler , segment_index = idx )
349+
350+ assert np .allclose (
351+ orig_seg_data [idx ], times_recording .get_times (segment_index = idx ) - shift * scaler , rtol = 0 , atol = 1e-8
352+ )
353+
354+ # Just do a little check that we are not
355+ # accidentally changing some other segments,
356+ # which should remain unchanged at this point in the loop.
357+ if idx != num_segments - 1 :
358+ assert np .array_equal (orig_seg_data [idx + 1 ], times_recording .get_times (segment_index = idx + 1 ))
359+
360+ @pytest .mark .parametrize ("fixture_name" , ["time_vector_recording" , "t_start_recording" ])
361+ def test_save_and_load_time_shift (self , request , fixture_name , tmp_path ):
362+ """
363+ Save the shifted data and check the shift is propagated correctly.
364+ """
365+ _ , times_recording , all_times = self ._get_fixture_data (request , fixture_name )
366+
367+ shift = 100
368+ times_recording .shift_times (shift = shift )
369+
370+ times_recording .save (folder = tmp_path / "my_file" )
371+
372+ loaded_recording = si .load_extractor (tmp_path / "my_file" )
373+
374+ for idx in range (times_recording .get_num_segments ()):
375+ assert np .array_equal (
376+ times_recording .get_times (segment_index = idx ), loaded_recording .get_times (segment_index = idx )
377+ )
378+
379+ def _store_all_times (self , recording ):
380+ """
381+ Convenience function to store original times of all segments to a dict.
382+ """
383+ num_segments = recording .get_num_segments ()
384+ seg_data = {}
385+
386+ for idx in range (num_segments ):
387+ seg_data [idx ] = copy .deepcopy (recording .get_times (segment_index = idx ))
388+
389+ return num_segments , seg_data
390+
391+ # #########################################################################
392+ # Helpers
393+ # #########################################################################
394+
309395 def _check_times_match (self , recording , all_times ):
310396 """
311397 For every segment in a recording, check the `get_times()`
0 commit comments