diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py
index f501e0177..d58114121 100644
--- a/src/torchcodec/samplers/_time_based.py
+++ b/src/torchcodec/samplers/_time_based.py
@@ -201,6 +201,14 @@ def _generic_time_based_sampler(
         sampling_range_end,  # excluded
         seconds_between_clip_starts,
     )
+    # As mentioned in the docs, torch.arange may return values equal to or
+    # above `end` because of floating point precision errors. Here, we
+    # manually ensure all values are strictly lower than `sampling_range_end`.
+    if clip_start_seconds[-1] >= sampling_range_end:
+        clip_start_seconds = clip_start_seconds[
+            clip_start_seconds < sampling_range_end
+        ]
+
     num_clips = len(clip_start_seconds)
 
     all_clips_timestamps = _build_all_clips_timestamps(
diff --git a/test/test_samplers.py b/test/test_samplers.py
index 938be0d91..10c529062 100644
--- a/test/test_samplers.py
+++ b/test/test_samplers.py
@@ -19,7 +19,7 @@
 from torchcodec.samplers._index_based import _build_all_clips_indices
 from torchcodec.samplers._time_based import _build_all_clips_timestamps
 
-from .utils import assert_frames_equal, NASA_VIDEO
+from .utils import assert_frames_equal, H265_10BITS, NASA_VIDEO
 
 
 def _assert_output_type_and_shapes(
@@ -698,3 +698,36 @@ def test_build_all_clips_timestamps(
     assert all(isinstance(timestamp, float) for timestamp in all_clips_timestamps)
     assert len(all_clips_timestamps) == len(clip_start_seconds) * NUM_FRAMES_PER_CLIP
     assert all_clips_timestamps == expected_all_clips_timestamps
+
+
+@pytest.mark.parametrize("policy", ("repeat_last", "wrap", "error"))
+def test_floating_point_precision_in_clips_at_regular_timestamps(policy):
+    # Test that floating point precision errors in torch.arange do not produce empty clips.
+    # A step of ~1/3 can cause arange to include sampling_range_end, which would then be
+    # filtered out in _build_all_clips_timestamps, leaving clips with no frames.
+    # The fix above drops such clip start values before the clips are built.
+    seconds_between_clip_starts = 1 / 3 - 1e-9
+
+    decoder = VideoDecoder(H265_10BITS.path)  # Video is 1 second long
+    # Set the sampling range so that the last clip's frame timestamp is ≈ end_stream_seconds
+    sampling_range_start = 0
+    sampling_range_end = decoder.metadata.end_stream_seconds
+    seconds_between_frames = 1
+    num_frames_per_clip = 1
+
+    clips = clips_at_regular_timestamps(
+        decoder,
+        seconds_between_clip_starts=seconds_between_clip_starts,
+        sampling_range_start=sampling_range_start,
+        sampling_range_end=sampling_range_end,
+        num_frames_per_clip=num_frames_per_clip,
+        seconds_between_frames=seconds_between_frames,
+        policy=policy,
+    )
+
+    # Ensure every returned frame PTS can actually be decoded
+    for clip in clips:
+        frames = decoder.get_frames_played_at(seconds=clip.pts_seconds.tolist())
+        assert isinstance(frames, FrameBatch)
+        assert frames.data.shape[0] == len(clip.pts_seconds)
+        assert len(clip.pts_seconds) == num_frames_per_clip
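
A minimal standalone sketch of the guard added in the first hunk, assuming only torch itself (the values below are illustrative, not taken from the patch): torch.arange with a floating point step can, as its documentation warns, produce a final value at or beyond `end` due to rounding, so clip starts are filtered to stay strictly below `sampling_range_end`:

    import torch

    sampling_range_end = 1.0  # illustrative; stands in for metadata.end_stream_seconds
    seconds_between_clip_starts = 1 / 3 - 1e-9  # same step as the new test

    clip_start_seconds = torch.arange(0, sampling_range_end, seconds_between_clip_starts)
    # If accumulated rounding error pushes the last value to or past `end`,
    # drop it so every remaining clip start maps to a decodable frame.
    if clip_start_seconds[-1] >= sampling_range_end:
        clip_start_seconds = clip_start_seconds[clip_start_seconds < sampling_range_end]
    print(clip_start_seconds)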