Skip to content

Commit 31b94df

Browse files
committed
Start adding tests.
1 parent 27360da commit 31b94df

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

src/spikeinterface/core/tests/test_time_handling.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,47 @@ def test_slice_recording(self, time_type, bounds):
243243

244244
assert np.allclose(rec_slice.get_times(0), all_times[0][start_frame:end_frame], rtol=0, atol=1e-8)
245245

246+
def test_get_durations(self, time_vector_recording, t_start_recording):
247+
"""
248+
Test the `get_durations` functions that return the total duration
249+
for a segment. Test that it is correct after adding both `t_start`
250+
or `time_vector` to the recording.
251+
"""
252+
raw_recording, tvector_recording, all_time_vectors = time_vector_recording
253+
_, tstart_recording, all_t_starts = t_start_recording
254+
255+
ts = 1 / raw_recording.get_sampling_frequency()
256+
257+
all_raw_durations = []
258+
all_vector_durations = []
259+
for segment_index in range(raw_recording.get_num_segments()):
260+
261+
# Test before `t_start` and `t_start` (`t_start` is just an offset,
262+
# should not affect duration).
263+
raw_duration = all_t_starts[segment_index][-1] - all_t_starts[segment_index][0] + ts
264+
265+
assert np.isclose(raw_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8)
266+
assert np.isclose(tstart_recording.get_duration(segment_index), raw_duration, rtol=0, atol=1e-8)
267+
268+
# Test the duration from the time vector.
269+
vector_duration = all_time_vectors[segment_index][-1] - all_time_vectors[segment_index][0] + ts
270+
271+
assert tvector_recording.get_duration(segment_index) == vector_duration
272+
273+
all_raw_durations.append(raw_duration)
274+
all_vector_durations.append(vector_duration)
275+
276+
# Finally test the total recording duration
277+
assert np.isclose(tstart_recording.get_total_duration(), sum(all_raw_durations), rtol=0, atol=1e-8)
278+
assert np.isclose(tvector_recording.get_total_duration(), sum(all_vector_durations), rtol=0, atol=1e-8)
279+
280+
# def test_sorting_analyzer_get_durations(self, time_vector_recording):
281+
# """ """
282+
# breakpoint()
283+
# sorting = si.generate_sorting()
284+
# sorting_analyzer = si.create_sorting_analyzer(sorting, recording=None)
285+
# si.sorting_an
286+
246287
# Helpers ####
247288
def _check_times_match(self, recording, all_times):
248289
"""

0 commit comments

Comments
 (0)