|
25 | 25 | assert_equal, |
26 | 26 | cache, |
27 | 27 | cpu_and_cuda, |
| 28 | + cvcuda_to_pil_compatible_tensor, |
28 | 29 | freeze_rng_state, |
29 | 30 | ignore_jit_no_profile_information_warning, |
30 | 31 | make_bounding_boxes, |
@@ -5821,23 +5822,26 @@ def test_functional_signature(self, kernel, input_type): |
5821 | 5822 | def test_transform(self, make_input): |
5822 | 5823 | check_transform(transforms.RandomInvert(p=1), make_input()) |
5823 | 5824 |
|
| 5825 | + @pytest.mark.parametrize( |
| 5826 | + "make_input", |
| 5827 | + [ |
| 5828 | + make_image, |
| 5829 | + pytest.param( |
| 5830 | + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") |
| 5831 | + ), |
| 5832 | + ], |
| 5833 | + ) |
5824 | 5834 | @pytest.mark.parametrize("fn", [F.invert, transform_cls_to_functional(transforms.RandomInvert, p=1)]) |
5825 | | - def test_correctness_image(self, fn): |
5826 | | - image = make_image(dtype=torch.uint8, device="cpu") |
| 5835 | + def test_correctness_image(self, make_input, fn): |
| 5836 | + image = make_input(dtype=torch.uint8, device="cpu") |
5827 | 5837 |
|
5828 | 5838 | actual = fn(image) |
5829 | | - expected = F.to_image(F.invert(F.to_pil_image(image))) |
5830 | 5839 |
|
5831 | | - assert_equal(actual, expected) |
| 5840 | + if make_input is make_image_cvcuda: |
| 5841 | + image = cvcuda_to_pil_compatible_tensor(image) |
| 5842 | + |
| 5843 | + expected = F.to_image(F.invert(F.to_pil_image(image))) |
5832 | 5844 |
|
5833 | | - @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") |
5834 | | - @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) |
5835 | | - @pytest.mark.parametrize("fn", [F.invert, transform_cls_to_functional(transforms.RandomInvert, p=1)]) |
5836 | | - def test_correctness_cvcuda(self, dtype, fn): |
5837 | | - image = make_image(batch_dims=(1,), dtype=dtype, device="cuda") |
5838 | | - cv_image = F.to_cvcuda_tensor(image) |
5839 | | - actual = F.cvcuda_to_tensor(fn(cv_image)) |
5840 | | - expected = F.invert_image(image) |
5841 | 5845 | assert_equal(actual, expected) |
5842 | 5846 |
|
5843 | 5847 |
|
|
0 commit comments