Skip to content

Commit 97bf7cc

Browse files
authored
test: rename cpu_and_cuda to all_supported_devices (#817)
1 parent e0d1587 commit 97bf7cc

File tree

4 files changed

+44
-44
lines changed

4 files changed

+44
-44
lines changed

test/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def pytest_collection_modifyitems(items):
2323
# with the @needs_cuda decorator. It will also exist if it was
2424
# parametrized with a parameter that has the mark: for example if a test
2525
# is parametrized with
26-
# @pytest.mark.parametrize('device', cpu_and_cuda())
26+
# @pytest.mark.parametrize('device', all_supported_devices())
2727
# the "instances" of the tests where device == 'cuda' will have the
2828
# 'needs_cuda' mark, and the ones with device == 'cpu' won't have the
2929
# mark.

test/test_decoders.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
)
2323

2424
from .utils import (
25+
all_supported_devices,
2526
assert_frames_equal,
2627
AV1_VIDEO,
27-
cpu_and_cuda,
2828
get_ffmpeg_major_version,
2929
H264_10BITS,
3030
H265_10BITS,
@@ -163,7 +163,7 @@ def test_create_fails(self):
163163
VideoDecoder(NASA_VIDEO.path, seek_mode="blah")
164164

165165
@pytest.mark.parametrize("num_ffmpeg_threads", (1, 4))
166-
@pytest.mark.parametrize("device", cpu_and_cuda())
166+
@pytest.mark.parametrize("device", all_supported_devices())
167167
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
168168
def test_getitem_int(self, num_ffmpeg_threads, device, seek_mode):
169169
decoder = VideoDecoder(
@@ -213,7 +213,7 @@ def test_getitem_numpy_int(self):
213213
assert_frames_equal(ref_frame1, decoder[numpy.uint32(1)])
214214
assert_frames_equal(ref_frame180, decoder[numpy.uint32(180)])
215215

216-
@pytest.mark.parametrize("device", cpu_and_cuda())
216+
@pytest.mark.parametrize("device", all_supported_devices())
217217
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
218218
def test_getitem_slice(self, device, seek_mode):
219219
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
@@ -373,7 +373,7 @@ def test_device_instance(self):
373373
decoder = VideoDecoder(NASA_VIDEO.path, device=torch.device("cpu"))
374374
assert isinstance(decoder.metadata, VideoStreamMetadata)
375375

376-
@pytest.mark.parametrize("device", cpu_and_cuda())
376+
@pytest.mark.parametrize("device", all_supported_devices())
377377
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
378378
def test_getitem_fails(self, device, seek_mode):
379379
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
@@ -390,7 +390,7 @@ def test_getitem_fails(self, device, seek_mode):
390390
with pytest.raises(TypeError, match="Unsupported key type"):
391391
frame = decoder[2.3] # noqa
392392

393-
@pytest.mark.parametrize("device", cpu_and_cuda())
393+
@pytest.mark.parametrize("device", all_supported_devices())
394394
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
395395
def test_iteration(self, device, seek_mode):
396396
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
@@ -437,7 +437,7 @@ def test_iteration_slow(self):
437437

438438
assert iterations == len(decoder) == 390
439439

440-
@pytest.mark.parametrize("device", cpu_and_cuda())
440+
@pytest.mark.parametrize("device", all_supported_devices())
441441
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
442442
def test_get_frame_at(self, device, seek_mode):
443443
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
@@ -475,7 +475,7 @@ def test_get_frame_at(self, device, seek_mode):
475475
frame9 = decoder.get_frame_at(numpy.uint32(9))
476476
assert_frames_equal(ref_frame9, frame9.data)
477477

478-
@pytest.mark.parametrize("device", cpu_and_cuda())
478+
@pytest.mark.parametrize("device", all_supported_devices())
479479
def test_get_frame_at_tuple_unpacking(self, device):
480480
decoder = VideoDecoder(NASA_VIDEO.path, device=device)
481481

@@ -486,7 +486,7 @@ def test_get_frame_at_tuple_unpacking(self, device):
486486
assert frame.pts_seconds == pts
487487
assert frame.duration_seconds == duration
488488

489-
@pytest.mark.parametrize("device", cpu_and_cuda())
489+
@pytest.mark.parametrize("device", all_supported_devices())
490490
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
491491
def test_get_frame_at_fails(self, device, seek_mode):
492492
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
@@ -500,7 +500,7 @@ def test_get_frame_at_fails(self, device, seek_mode):
500500
with pytest.raises(IndexError, match="must be less than"):
501501
frame = decoder.get_frame_at(10000) # noqa
502502

503-
@pytest.mark.parametrize("device", cpu_and_cuda())
503+
@pytest.mark.parametrize("device", all_supported_devices())
504504
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
505505
def test_get_frames_at(self, device, seek_mode):
506506
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
@@ -551,7 +551,7 @@ def test_get_frames_at(self, device, seek_mode):
551551
frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0
552552
)
553553

554-
@pytest.mark.parametrize("device", cpu_and_cuda())
554+
@pytest.mark.parametrize("device", all_supported_devices())
555555
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
556556
def test_get_frames_at_fails(self, device, seek_mode):
557557
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
@@ -568,7 +568,7 @@ def test_get_frames_at_fails(self, device, seek_mode):
568568
with pytest.raises(RuntimeError, match="Expected a value of type"):
569569
decoder.get_frames_at([0.3])
570570

571-
@pytest.mark.parametrize("device", cpu_and_cuda())
571+
@pytest.mark.parametrize("device", all_supported_devices())
572572
def test_get_frame_at_av1(self, device):
573573
if device == "cuda" and get_ffmpeg_major_version() == 4:
574574
return
@@ -581,7 +581,7 @@ def test_get_frame_at_av1(self, device):
581581
assert decoded_frame10.pts_seconds == ref_frame_info10.pts_seconds
582582
assert_frames_equal(decoded_frame10.data, ref_frame10.to(device=device))
583583

584-
@pytest.mark.parametrize("device", cpu_and_cuda())
584+
@pytest.mark.parametrize("device", all_supported_devices())
585585
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
586586
def test_get_frame_played_at(self, device, seek_mode):
587587
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
@@ -610,7 +610,7 @@ def test_get_frame_played_at_h265(self):
610610
ref_frame6 = H265_VIDEO.get_frame_data_by_index(5)
611611
assert_frames_equal(ref_frame6, decoder.get_frame_played_at(0.5).data)
612612

613-
@pytest.mark.parametrize("device", cpu_and_cuda())
613+
@pytest.mark.parametrize("device", all_supported_devices())
614614
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
615615
def test_get_frame_played_at_fails(self, device, seek_mode):
616616
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
@@ -621,7 +621,7 @@ def test_get_frame_played_at_fails(self, device, seek_mode):
621621
with pytest.raises(IndexError, match="Invalid pts in seconds"):
622622
frame = decoder.get_frame_played_at(100.0) # noqa
623623

624-
@pytest.mark.parametrize("device", cpu_and_cuda())
624+
@pytest.mark.parametrize("device", all_supported_devices())
625625
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
626626
def test_get_frames_played_at(self, device, seek_mode):
627627

@@ -660,7 +660,7 @@ def test_get_frames_played_at(self, device, seek_mode):
660660
frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0
661661
)
662662

663-
@pytest.mark.parametrize("device", cpu_and_cuda())
663+
@pytest.mark.parametrize("device", all_supported_devices())
664664
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
665665
def test_get_frames_played_at_fails(self, device, seek_mode):
666666
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
@@ -674,7 +674,7 @@ def test_get_frames_played_at_fails(self, device, seek_mode):
674674
with pytest.raises(RuntimeError, match="Expected a value of type"):
675675
decoder.get_frames_played_at(["bad"])
676676

677-
@pytest.mark.parametrize("device", cpu_and_cuda())
677+
@pytest.mark.parametrize("device", all_supported_devices())
678678
@pytest.mark.parametrize("stream_index", [0, 3, None])
679679
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
680680
def test_get_frames_in_range(self, stream_index, device, seek_mode):
@@ -779,7 +779,7 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
779779
empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds
780780
)
781781

782-
@pytest.mark.parametrize("device", cpu_and_cuda())
782+
@pytest.mark.parametrize("device", all_supported_devices())
783783
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
784784
def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode):
785785
decoder = VideoDecoder(
@@ -831,7 +831,7 @@ def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode):
831831
).to(device)
832832
assert_frames_equal(frames387_None.data, reference_frame387_389)
833833

834-
@pytest.mark.parametrize("device", cpu_and_cuda())
834+
@pytest.mark.parametrize("device", all_supported_devices())
835835
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
836836
@patch("torchcodec._core._metadata._get_stream_json_metadata")
837837
def test_get_frames_with_missing_num_frames_metadata(
@@ -894,7 +894,7 @@ def test_get_frames_with_missing_num_frames_metadata(
894894
lambda decoder: decoder.get_frames_played_in_range(0, 1).data,
895895
),
896896
)
897-
@pytest.mark.parametrize("device", cpu_and_cuda())
897+
@pytest.mark.parametrize("device", all_supported_devices())
898898
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
899899
def test_dimension_order(self, dimension_order, frame_getter, device, seek_mode):
900900
decoder = VideoDecoder(
@@ -922,7 +922,7 @@ def test_dimension_order_fails(self):
922922
VideoDecoder(NASA_VIDEO.path, dimension_order="NCDHW")
923923

924924
@pytest.mark.parametrize("stream_index", [0, 3, None])
925-
@pytest.mark.parametrize("device", cpu_and_cuda())
925+
@pytest.mark.parametrize("device", all_supported_devices())
926926
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
927927
def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode):
928928
decoder = VideoDecoder(
@@ -1061,7 +1061,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode):
10611061
)
10621062
assert_frames_equal(all_frames.data, decoder[:])
10631063

1064-
@pytest.mark.parametrize("device", cpu_and_cuda())
1064+
@pytest.mark.parametrize("device", all_supported_devices())
10651065
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
10661066
def test_get_frames_by_pts_in_range_fails(self, device, seek_mode):
10671067
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
@@ -1075,7 +1075,7 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode):
10751075
with pytest.raises(ValueError, match="Invalid stop seconds"):
10761076
frame = decoder.get_frames_played_in_range(0, 23) # noqa
10771077

1078-
@pytest.mark.parametrize("device", cpu_and_cuda())
1078+
@pytest.mark.parametrize("device", all_supported_devices())
10791079
def test_get_key_frame_indices(self, device):
10801080
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode="exact")
10811081
key_frame_indices = decoder._get_key_frame_indices()
@@ -1120,7 +1120,7 @@ def test_get_key_frame_indices(self, device):
11201120

11211121
# TODO investigate why this fails internally.
11221122
@pytest.mark.skipif(in_fbcode(), reason="Compile test fails internally.")
1123-
@pytest.mark.parametrize("device", cpu_and_cuda())
1123+
@pytest.mark.parametrize("device", all_supported_devices())
11241124
def test_compile(self, device):
11251125
decoder = VideoDecoder(NASA_VIDEO.path, device=device)
11261126

0 commit comments

Comments
 (0)