@@ -6277,7 +6277,18 @@ def test_kernel_image(self, dtype, device):
62776277 def test_kernel_video (self ):
62786278 check_kernel (F .adjust_saturation_video , make_video (), saturation_factor = 0.5 )
62796279
6280- @pytest .mark .parametrize ("make_input" , [make_image_tensor , make_image , make_image_pil , make_video ])
6280+ @pytest .mark .parametrize (
6281+ "make_input" ,
6282+ [
6283+ make_image_tensor ,
6284+ make_image ,
6285+ make_image_pil ,
6286+ make_video ,
6287+ pytest .param (
6288+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
6289+ ),
6290+ ],
6291+ )
62816292 def test_functional (self , make_input ):
62826293 check_functional (F .adjust_saturation , make_input (), saturation_factor = 0.5 )
62836294
@@ -6288,9 +6299,16 @@ def test_functional(self, make_input):
62886299 (F ._color ._adjust_saturation_image_pil , PIL .Image .Image ),
62896300 (F .adjust_saturation_image , tv_tensors .Image ),
62906301 (F .adjust_saturation_video , tv_tensors .Video ),
6302+ pytest .param (
6303+ F ._color ._adjust_saturation_cvcuda ,
6304+ "cvcuda.Tensor" ,
6305+ marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" ),
6306+ ),
62916307 ],
62926308 )
62936309 def test_functional_signature (self , kernel , input_type ):
6310+ if input_type == "cvcuda.Tensor" :
6311+ input_type = _import_cvcuda ().Tensor
62946312 check_functional_kernel_signature_match (F .adjust_saturation , kernel = kernel , input_type = input_type )
62956313
62966314 def test_functional_error (self ):
@@ -6300,11 +6318,28 @@ def test_functional_error(self):
63006318 with pytest .raises (ValueError , match = "is not non-negative" ):
63016319 F .adjust_saturation (make_image (), saturation_factor = - 1 )
63026320
6321+ @pytest .mark .parametrize (
6322+ "make_input" ,
6323+ [
6324+ make_image ,
6325+ pytest .param (
6326+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
6327+ ),
6328+ ],
6329+ )
6330+ @pytest .mark .parametrize ("color_space" , ["RGB" , "GRAY" ])
63036331 @pytest .mark .parametrize ("saturation_factor" , [0.1 , 0.5 , 1.0 ])
6304- def test_correctness_image (self , saturation_factor ):
6305- image = make_image (dtype = torch .uint8 , device = "cpu" )
6332+ def test_correctness_image (self , make_input , color_space , saturation_factor ):
6333+ image = make_input (dtype = torch .uint8 , color_space = color_space , device = "cpu" )
63066334
63076335 actual = F .adjust_saturation (image , saturation_factor = saturation_factor )
6336+
6337+ if make_input is make_image_cvcuda :
6338+ actual = F .cvcuda_to_tensor (actual ).to (device = "cpu" )
6339+ actual = actual .squeeze (0 )
6340+ image = F .cvcuda_to_tensor (image )
6341+ image = image .squeeze (0 )
6342+
63086343 expected = F .to_image (F .adjust_saturation (F .to_pil_image (image ), saturation_factor = saturation_factor ))
63096344
63106345 assert_close (actual , expected , rtol = 0 , atol = 1 )
0 commit comments