@@ -5772,7 +5772,18 @@ def test_kernel_image(self, dtype, device):
57725772 def test_kernel_video (self ):
57735773 check_kernel (F .invert_video , make_video ())
57745774
5775- @pytest .mark .parametrize ("make_input" , [make_image_tensor , make_image , make_image_pil , make_video ])
5775+ @pytest .mark .parametrize (
5776+ "make_input" ,
5777+ [
5778+ make_image_tensor ,
5779+ make_image ,
5780+ make_image_pil ,
5781+ make_video ,
5782+ pytest .param (
5783+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" )
5784+ ),
5785+ ],
5786+ )
57765787 def test_functional (self , make_input ):
57775788 check_functional (F .invert , make_input ())
57785789
@@ -5783,12 +5794,30 @@ def test_functional(self, make_input):
57835794 (F ._color ._invert_image_pil , PIL .Image .Image ),
57845795 (F .invert_image , tv_tensors .Image ),
57855796 (F .invert_video , tv_tensors .Video ),
5797+ pytest .param (
5798+ F ._color ._invert_cvcuda ,
5799+ "cvcuda.Tensor" ,
5800+ marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" ),
5801+ ),
57865802 ],
57875803 )
57885804 def test_functional_signature (self , kernel , input_type ):
5805+ if input_type == "cvcuda.Tensor" :
5806+ input_type = _import_cvcuda ().Tensor
57895807 check_functional_kernel_signature_match (F .invert , kernel = kernel , input_type = input_type )
57905808
5791- @pytest .mark .parametrize ("make_input" , [make_image_tensor , make_image_pil , make_image , make_video ])
5809+ @pytest .mark .parametrize (
5810+ "make_input" ,
5811+ [
5812+ make_image_tensor ,
5813+ make_image_pil ,
5814+ make_image ,
5815+ make_video ,
5816+ pytest .param (
5817+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" )
5818+ ),
5819+ ],
5820+ )
57925821 def test_transform (self , make_input ):
57935822 check_transform (transforms .RandomInvert (p = 1 ), make_input ())
57945823
@@ -5801,6 +5830,16 @@ def test_correctness_image(self, fn):
58015830
58025831 assert_equal (actual , expected )
58035832
5833+ @pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" )
5834+ @pytest .mark .parametrize ("dtype" , [torch .uint8 , torch .float32 ])
5835+ @pytest .mark .parametrize ("fn" , [F .invert , transform_cls_to_functional (transforms .RandomInvert , p = 1 )])
5836+ def test_correctness_cvcuda (self , dtype , fn ):
5837+ image = make_image (batch_dims = (1 ,), dtype = dtype , device = "cuda" )
5838+ cv_image = F .to_cvcuda_tensor (image )
5839+ actual = F .cvcuda_to_tensor (fn (cv_image ))
5840+ expected = F .invert_image (image )
5841+ assert_equal (actual , expected )
5842+
58045843
58055844class TestPosterize :
58065845 @pytest .mark .parametrize ("dtype" , [torch .uint8 , torch .float32 ])
0 commit comments