|
19 | 19 | from torchcodec.samplers._index_based import _build_all_clips_indices
|
20 | 20 | from torchcodec.samplers._time_based import _build_all_clips_timestamps
|
21 | 21 |
|
22 |
| -from .utils import assert_frames_equal, NASA_VIDEO |
| 22 | +from .utils import assert_frames_equal, H265_10BITS, NASA_VIDEO |
23 | 23 |
|
24 | 24 |
|
25 | 25 | def _assert_output_type_and_shapes(
|
@@ -698,3 +698,36 @@ def test_build_all_clips_timestamps(
|
698 | 698 | assert all(isinstance(timestamp, float) for timestamp in all_clips_timestamps)
|
699 | 699 | assert len(all_clips_timestamps) == len(clip_start_seconds) * NUM_FRAMES_PER_CLIP
|
700 | 700 | assert all_clips_timestamps == expected_all_clips_timestamps
|
| 701 | + |
| 702 | + |
| 703 | +@pytest.mark.parametrize("policy", ("repeat_last", "wrap", "error")) |
| 704 | +def test_floating_point_precision_in_clips_at_regular_timestamps(policy): |
| 705 | + # Test that floating point precision errors in torch.arange do not return empty clips. |
| 706 | + # Using 1/3 would cause arange to include sampling_range_end, which gets filtered out |
| 707 | + # in _build_all_clips_timestamps, leaving clips with no frames. |
| 708 | + # The fix rounds seconds_between_clip_starts to prevent this. |
| 709 | + seconds_between_clip_starts = 1 / 3 - 1e-9 |
| 710 | + |
| 711 | + decoder = VideoDecoder(H265_10BITS.path) # Video is 1 second long |
| 712 | + # Set sampling range so that last clip will have frame timestamp ≈ end_stream_seconds |
| 713 | + sampling_range_start = 0 |
| 714 | + sampling_range_end = decoder.metadata.end_stream_seconds |
| 715 | + seconds_between_frames = 1 |
| 716 | + num_frames_per_clip = 1 |
| 717 | + |
| 718 | + clips = clips_at_regular_timestamps( |
| 719 | + decoder, |
| 720 | + seconds_between_clip_starts=seconds_between_clip_starts, |
| 721 | + sampling_range_start=sampling_range_start, |
| 722 | + sampling_range_end=sampling_range_end, |
| 723 | + num_frames_per_clip=num_frames_per_clip, |
| 724 | + seconds_between_frames=seconds_between_frames, |
| 725 | + policy=policy, |
| 726 | + ) |
| 727 | + |
| 728 | + # Ensure frame PTS can be decoded |
| 729 | + for clip in clips: |
| 730 | + frames = decoder.get_frames_played_at(seconds=clip.pts_seconds.tolist()) |
| 731 | + assert isinstance(frames, FrameBatch) |
| 732 | + assert frames.data.shape[0] == len(clip.pts_seconds) |
| 733 | + assert len(clip.pts_seconds) == num_frames_per_clip |
0 commit comments