Skip to content

Commit 6f906f4

Browse files
Dan-Flores and Daniel Flores authored
Fix sampling index error (#901)
Co-authored-by: Daniel Flores <danielflores3@fb.com>
1 parent fc60ed6 commit 6f906f4

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

src/torchcodec/samplers/_time_based.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,14 @@ def _generic_time_based_sampler(
201201
sampling_range_end, # excluded
202202
seconds_between_clip_starts,
203203
)
204+
# As mentioned in the docs, torch.arange may return values
205+
# equal to or above `end` because of floating precision errors.
206+
# Here, we manually ensure all values are strictly lower than `sampling_range_end`
207+
if clip_start_seconds[-1] >= sampling_range_end:
208+
clip_start_seconds = clip_start_seconds[
209+
clip_start_seconds < sampling_range_end
210+
]
211+
204212
num_clips = len(clip_start_seconds)
205213

206214
all_clips_timestamps = _build_all_clips_timestamps(

test/test_samplers.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from torchcodec.samplers._index_based import _build_all_clips_indices
2020
from torchcodec.samplers._time_based import _build_all_clips_timestamps
2121

22-
from .utils import assert_frames_equal, NASA_VIDEO
22+
from .utils import assert_frames_equal, H265_10BITS, NASA_VIDEO
2323

2424

2525
def _assert_output_type_and_shapes(
@@ -698,3 +698,36 @@ def test_build_all_clips_timestamps(
698698
assert all(isinstance(timestamp, float) for timestamp in all_clips_timestamps)
699699
assert len(all_clips_timestamps) == len(clip_start_seconds) * NUM_FRAMES_PER_CLIP
700700
assert all_clips_timestamps == expected_all_clips_timestamps
701+
702+
703+
@pytest.mark.parametrize("policy", ("repeat_last", "wrap", "error"))
704+
def test_floating_point_precision_in_clips_at_regular_timestamps(policy):
705+
# Test that floating point precision errors in torch.arange do not return empty clips.
706+
# Using 1/3 would cause arange to include sampling_range_end, which gets filtered out
707+
# in _build_all_clips_timestamps, leaving clips with no frames.
708+
# The fix drops any clip start >= sampling_range_end after torch.arange to prevent this.
709+
seconds_between_clip_starts = 1 / 3 - 1e-9
710+
711+
decoder = VideoDecoder(H265_10BITS.path) # Video is 1 second long
712+
# Set sampling range so that last clip will have frame timestamp ≈ end_stream_seconds
713+
sampling_range_start = 0
714+
sampling_range_end = decoder.metadata.end_stream_seconds
715+
seconds_between_frames = 1
716+
num_frames_per_clip = 1
717+
718+
clips = clips_at_regular_timestamps(
719+
decoder,
720+
seconds_between_clip_starts=seconds_between_clip_starts,
721+
sampling_range_start=sampling_range_start,
722+
sampling_range_end=sampling_range_end,
723+
num_frames_per_clip=num_frames_per_clip,
724+
seconds_between_frames=seconds_between_frames,
725+
policy=policy,
726+
)
727+
728+
# Ensure frame PTS can be decoded
729+
for clip in clips:
730+
frames = decoder.get_frames_played_at(seconds=clip.pts_seconds.tolist())
731+
assert isinstance(frames, FrameBatch)
732+
assert frames.data.shape[0] == len(clip.pts_seconds)
733+
assert len(clip.pts_seconds) == num_frames_per_clip

0 commit comments

Comments
 (0)