22
22
)
23
23
24
24
from .utils import (
25
+ all_supported_devices ,
25
26
assert_frames_equal ,
26
27
AV1_VIDEO ,
27
- cpu_and_cuda ,
28
28
get_ffmpeg_major_version ,
29
29
H264_10BITS ,
30
30
H265_10BITS ,
@@ -163,7 +163,7 @@ def test_create_fails(self):
163
163
VideoDecoder (NASA_VIDEO .path , seek_mode = "blah" )
164
164
165
165
@pytest .mark .parametrize ("num_ffmpeg_threads" , (1 , 4 ))
166
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
166
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
167
167
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
168
168
def test_getitem_int (self , num_ffmpeg_threads , device , seek_mode ):
169
169
decoder = VideoDecoder (
@@ -213,7 +213,7 @@ def test_getitem_numpy_int(self):
213
213
assert_frames_equal (ref_frame1 , decoder [numpy .uint32 (1 )])
214
214
assert_frames_equal (ref_frame180 , decoder [numpy .uint32 (180 )])
215
215
216
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
216
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
217
217
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
218
218
def test_getitem_slice (self , device , seek_mode ):
219
219
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -373,7 +373,7 @@ def test_device_instance(self):
373
373
decoder = VideoDecoder (NASA_VIDEO .path , device = torch .device ("cpu" ))
374
374
assert isinstance (decoder .metadata , VideoStreamMetadata )
375
375
376
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
376
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
377
377
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
378
378
def test_getitem_fails (self , device , seek_mode ):
379
379
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -390,7 +390,7 @@ def test_getitem_fails(self, device, seek_mode):
390
390
with pytest .raises (TypeError , match = "Unsupported key type" ):
391
391
frame = decoder [2.3 ] # noqa
392
392
393
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
393
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
394
394
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
395
395
def test_iteration (self , device , seek_mode ):
396
396
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -437,7 +437,7 @@ def test_iteration_slow(self):
437
437
438
438
assert iterations == len (decoder ) == 390
439
439
440
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
440
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
441
441
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
442
442
def test_get_frame_at (self , device , seek_mode ):
443
443
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -475,7 +475,7 @@ def test_get_frame_at(self, device, seek_mode):
475
475
frame9 = decoder .get_frame_at (numpy .uint32 (9 ))
476
476
assert_frames_equal (ref_frame9 , frame9 .data )
477
477
478
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
478
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
479
479
def test_get_frame_at_tuple_unpacking (self , device ):
480
480
decoder = VideoDecoder (NASA_VIDEO .path , device = device )
481
481
@@ -486,7 +486,7 @@ def test_get_frame_at_tuple_unpacking(self, device):
486
486
assert frame .pts_seconds == pts
487
487
assert frame .duration_seconds == duration
488
488
489
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
489
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
490
490
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
491
491
def test_get_frame_at_fails (self , device , seek_mode ):
492
492
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -500,7 +500,7 @@ def test_get_frame_at_fails(self, device, seek_mode):
500
500
with pytest .raises (IndexError , match = "must be less than" ):
501
501
frame = decoder .get_frame_at (10000 ) # noqa
502
502
503
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
503
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
504
504
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
505
505
def test_get_frames_at (self , device , seek_mode ):
506
506
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -551,7 +551,7 @@ def test_get_frames_at(self, device, seek_mode):
551
551
frames .duration_seconds , expected_duration_seconds , atol = 1e-4 , rtol = 0
552
552
)
553
553
554
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
554
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
555
555
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
556
556
def test_get_frames_at_fails (self , device , seek_mode ):
557
557
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -568,7 +568,7 @@ def test_get_frames_at_fails(self, device, seek_mode):
568
568
with pytest .raises (RuntimeError , match = "Expected a value of type" ):
569
569
decoder .get_frames_at ([0.3 ])
570
570
571
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
571
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
572
572
def test_get_frame_at_av1 (self , device ):
573
573
if device == "cuda" and get_ffmpeg_major_version () == 4 :
574
574
return
@@ -581,7 +581,7 @@ def test_get_frame_at_av1(self, device):
581
581
assert decoded_frame10 .pts_seconds == ref_frame_info10 .pts_seconds
582
582
assert_frames_equal (decoded_frame10 .data , ref_frame10 .to (device = device ))
583
583
584
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
584
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
585
585
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
586
586
def test_get_frame_played_at (self , device , seek_mode ):
587
587
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -610,7 +610,7 @@ def test_get_frame_played_at_h265(self):
610
610
ref_frame6 = H265_VIDEO .get_frame_data_by_index (5 )
611
611
assert_frames_equal (ref_frame6 , decoder .get_frame_played_at (0.5 ).data )
612
612
613
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
613
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
614
614
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
615
615
def test_get_frame_played_at_fails (self , device , seek_mode ):
616
616
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -621,7 +621,7 @@ def test_get_frame_played_at_fails(self, device, seek_mode):
621
621
with pytest .raises (IndexError , match = "Invalid pts in seconds" ):
622
622
frame = decoder .get_frame_played_at (100.0 ) # noqa
623
623
624
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
624
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
625
625
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
626
626
def test_get_frames_played_at (self , device , seek_mode ):
627
627
@@ -660,7 +660,7 @@ def test_get_frames_played_at(self, device, seek_mode):
660
660
frames .duration_seconds , expected_duration_seconds , atol = 1e-4 , rtol = 0
661
661
)
662
662
663
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
663
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
664
664
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
665
665
def test_get_frames_played_at_fails (self , device , seek_mode ):
666
666
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -674,7 +674,7 @@ def test_get_frames_played_at_fails(self, device, seek_mode):
674
674
with pytest .raises (RuntimeError , match = "Expected a value of type" ):
675
675
decoder .get_frames_played_at (["bad" ])
676
676
677
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
677
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
678
678
@pytest .mark .parametrize ("stream_index" , [0 , 3 , None ])
679
679
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
680
680
def test_get_frames_in_range (self , stream_index , device , seek_mode ):
@@ -779,7 +779,7 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
779
779
empty_frames .duration_seconds , NASA_VIDEO .empty_duration_seconds
780
780
)
781
781
782
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
782
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
783
783
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
784
784
def test_get_frames_in_range_slice_indices_syntax (self , device , seek_mode ):
785
785
decoder = VideoDecoder (
@@ -831,7 +831,7 @@ def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode):
831
831
).to (device )
832
832
assert_frames_equal (frames387_None .data , reference_frame387_389 )
833
833
834
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
834
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
835
835
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
836
836
@patch ("torchcodec._core._metadata._get_stream_json_metadata" )
837
837
def test_get_frames_with_missing_num_frames_metadata (
@@ -894,7 +894,7 @@ def test_get_frames_with_missing_num_frames_metadata(
894
894
lambda decoder : decoder .get_frames_played_in_range (0 , 1 ).data ,
895
895
),
896
896
)
897
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
897
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
898
898
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
899
899
def test_dimension_order (self , dimension_order , frame_getter , device , seek_mode ):
900
900
decoder = VideoDecoder (
@@ -922,7 +922,7 @@ def test_dimension_order_fails(self):
922
922
VideoDecoder (NASA_VIDEO .path , dimension_order = "NCDHW" )
923
923
924
924
@pytest .mark .parametrize ("stream_index" , [0 , 3 , None ])
925
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
925
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
926
926
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
927
927
def test_get_frames_by_pts_in_range (self , stream_index , device , seek_mode ):
928
928
decoder = VideoDecoder (
@@ -1061,7 +1061,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode):
1061
1061
)
1062
1062
assert_frames_equal (all_frames .data , decoder [:])
1063
1063
1064
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
1064
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
1065
1065
@pytest .mark .parametrize ("seek_mode" , ("exact" , "approximate" ))
1066
1066
def test_get_frames_by_pts_in_range_fails (self , device , seek_mode ):
1067
1067
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = seek_mode )
@@ -1075,7 +1075,7 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode):
1075
1075
with pytest .raises (ValueError , match = "Invalid stop seconds" ):
1076
1076
frame = decoder .get_frames_played_in_range (0 , 23 ) # noqa
1077
1077
1078
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
1078
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
1079
1079
def test_get_key_frame_indices (self , device ):
1080
1080
decoder = VideoDecoder (NASA_VIDEO .path , device = device , seek_mode = "exact" )
1081
1081
key_frame_indices = decoder ._get_key_frame_indices ()
@@ -1120,7 +1120,7 @@ def test_get_key_frame_indices(self, device):
1120
1120
1121
1121
# TODO investigate why this fails internally.
1122
1122
@pytest .mark .skipif (in_fbcode (), reason = "Compile test fails internally." )
1123
- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
1123
+ @pytest .mark .parametrize ("device" , all_supported_devices ())
1124
1124
def test_compile (self , device ):
1125
1125
decoder = VideoDecoder (NASA_VIDEO .path , device = device )
1126
1126
0 commit comments