From 6358c4bb3ea8a0533674767d02ff02cea286798f Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Wed, 17 Sep 2025 16:55:37 -0400 Subject: [PATCH 1/7] round step in torch arange --- src/torchcodec/samplers/_time_based.py | 2 +- test/test_samplers.py | 35 +++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index f501e0177..e01e9d856 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -199,7 +199,7 @@ def _generic_time_based_sampler( clip_start_seconds = torch.arange( sampling_range_start, sampling_range_end, # excluded - seconds_between_clip_starts, + round(seconds_between_clip_starts, 6), ) num_clips = len(clip_start_seconds) diff --git a/test/test_samplers.py b/test/test_samplers.py index 938be0d91..36685bc46 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -19,7 +19,7 @@ from torchcodec.samplers._index_based import _build_all_clips_indices from torchcodec.samplers._time_based import _build_all_clips_timestamps -from .utils import assert_frames_equal, NASA_VIDEO +from .utils import assert_frames_equal, H265_10BITS, NASA_VIDEO def _assert_output_type_and_shapes( @@ -698,3 +698,36 @@ def test_build_all_clips_timestamps( assert all(isinstance(timestamp, float) for timestamp in all_clips_timestamps) assert len(all_clips_timestamps) == len(clip_start_seconds) * NUM_FRAMES_PER_CLIP assert all_clips_timestamps == expected_all_clips_timestamps + + +@pytest.mark.parametrize("policy", ("repeat_last", "wrap", "error")) +def test_floating_point_precision_in_clips_at_regular_timestamps(policy): + # Test that floating point precision errors in torch.arange do not return empty clips. + # Using 1/3 would cause arange to include sampling_range_end, which gets filtered out + # in _build_all_clips_timestamps, leaving clips with no frames. + # The fix rounds seconds_between_clip_starts to prevent this. + seconds_between_clip_starts = 1 / 3 + + decoder = VideoDecoder(H265_10BITS.path) # Video is 1 second long + # Set sampling range so that last clip will have frame timestamp ≈ end_stream_seconds + sampling_range_start = 0 + sampling_range_end = decoder.metadata.end_stream_seconds + seconds_between_frames = 1 + num_frames_per_clip = 1 + + clips = clips_at_regular_timestamps( + decoder, + seconds_between_clip_starts=seconds_between_clip_starts, + sampling_range_start=sampling_range_start, + sampling_range_end=sampling_range_end, + num_frames_per_clip=num_frames_per_clip, + seconds_between_frames=seconds_between_frames, + policy=policy, + ) + + # Ensure frame PTS can be decoded + for clip in clips: + frames = decoder.get_frames_played_at(seconds=clip.pts_seconds.tolist()) + assert isinstance(frames, FrameBatch) + assert frames.data.shape[0] == len(clip.pts_seconds) + assert len(clip.pts_seconds) == num_frames_per_clip From aa4c960907b5f5a3c36cde3e7cf6da3abc8c590d Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Thu, 18 Sep 2025 13:35:37 -0400 Subject: [PATCH 2/7] make rounding error reproducible --- test/test_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_samplers.py b/test/test_samplers.py index 36685bc46..10c529062 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -706,7 +706,7 @@ def test_floating_point_precision_in_clips_at_regular_timestamps(policy): # Using 1/3 would cause arange to include sampling_range_end, which gets filtered out # in _build_all_clips_timestamps, leaving clips with no frames. # The fix rounds seconds_between_clip_starts to prevent this. - seconds_between_clip_starts = 1 / 3 + seconds_between_clip_starts = 1 / 3 - 1e-9 decoder = VideoDecoder(H265_10BITS.path) # Video is 1 second long # Set sampling range so that last clip will have frame timestamp ≈ end_stream_seconds From b0e97722dad4bc3d6e919df100a04d94d099ec7b Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Thu, 18 Sep 2025 13:36:41 -0400 Subject: [PATCH 3/7] drop bad included index in range --- src/torchcodec/samplers/_time_based.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index e01e9d856..d11950970 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -196,11 +196,18 @@ def _generic_time_based_sampler( ) else: assert seconds_between_clip_starts is not None # appease type-checker + # The torch.arange documentation warns that floating point rounding errors + # are possible for non-integer steps when comparing to end. + # To prevent this, we check if the last clip_start_seconds value is + # equal to the end value, and remove it. + # docs.pytorch.org/docs/2.8/generated/torch.arange.html clip_start_seconds = torch.arange( sampling_range_start, sampling_range_end, # excluded - round(seconds_between_clip_starts, 6), + seconds_between_clip_starts, ) + if clip_start_seconds[-1] >= sampling_range_end: + clip_start_seconds = clip_start_seconds[:-1] num_clips = len(clip_start_seconds) all_clips_timestamps = _build_all_clips_timestamps( From 46a0b5b6a939de2e60a25f34e318f7d4d0ac3990 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Fri, 19 Sep 2025 11:08:04 -0400 Subject: [PATCH 4/7] subtract epsilon in torch.arange --- src/torchcodec/samplers/_time_based.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index d11950970..c35c89948 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -204,10 +204,8 @@ def _generic_time_based_sampler( clip_start_seconds = torch.arange( sampling_range_start, sampling_range_end, # excluded - seconds_between_clip_starts, + seconds_between_clip_starts - 1e-6, ) - if clip_start_seconds[-1] >= sampling_range_end: - clip_start_seconds = clip_start_seconds[:-1] num_clips = len(clip_start_seconds) all_clips_timestamps = _build_all_clips_timestamps( From 183a134857f1a862735c15ae1dcc3c2acbb30b28 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 22 Sep 2025 09:54:24 -0400 Subject: [PATCH 5/7] drop values >= end range --- src/torchcodec/samplers/_time_based.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index c35c89948..d11950970 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -204,8 +204,10 @@ def _generic_time_based_sampler( clip_start_seconds = torch.arange( sampling_range_start, sampling_range_end, # excluded - seconds_between_clip_starts - 1e-6, + seconds_between_clip_starts, ) + if clip_start_seconds[-1] >= sampling_range_end: + clip_start_seconds = clip_start_seconds[:-1] num_clips = len(clip_start_seconds) all_clips_timestamps = _build_all_clips_timestamps( From 4e467dd5982e7188fbbfb6f3f95e75ccf1e94e5b Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 22 Sep 2025 11:48:44 -0400 Subject: [PATCH 6/7] remove all values >= end --- src/torchcodec/samplers/_time_based.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index d11950970..70ef61d93 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -198,16 +198,20 @@ def _generic_time_based_sampler( assert seconds_between_clip_starts is not None # appease type-checker # The torch.arange documentation warns that floating point rounding errors # are possible for non-integer steps when comparing to end. - # To prevent this, we check if the last clip_start_seconds value is - # equal to the end value, and remove it. # docs.pytorch.org/docs/2.8/generated/torch.arange.html clip_start_seconds = torch.arange( sampling_range_start, sampling_range_end, # excluded seconds_between_clip_starts, ) + # As mentioned in the docs, torch.arange may return values + # equal to or above `end` because of floating precision errors. + # Here, we manually ensure all values are strictly lower than `sample_range_end` if clip_start_seconds[-1] >= sampling_range_end: - clip_start_seconds = clip_start_seconds[:-1] + clip_start_seconds = clip_start_seconds[ + clip_start_seconds < sampling_range_end + ] + num_clips = len(clip_start_seconds) all_clips_timestamps = _build_all_clips_timestamps( From 51f47c66c1ac190ee790588fa0069ce7e6b15e56 Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Mon, 22 Sep 2025 12:32:38 -0400 Subject: [PATCH 7/7] reduce comments --- src/torchcodec/samplers/_time_based.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index 70ef61d93..d58114121 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -196,9 +196,6 @@ def _generic_time_based_sampler( ) else: assert seconds_between_clip_starts is not None # appease type-checker - # The torch.arange documentation warns that floating point rounding errors - # are possible for non-integer steps when comparing to end. - # docs.pytorch.org/docs/2.8/generated/torch.arange.html clip_start_seconds = torch.arange( sampling_range_start, sampling_range_end, # excluded