@@ -5398,7 +5398,18 @@ def test_kernel_image(self, dtype, device):
53985398 def test_kernel_video (self ):
53995399 check_kernel (F .equalize_image , make_video ())
54005400
5401- @pytest .mark .parametrize ("make_input" , [make_image_tensor , make_image_pil , make_image , make_video ])
5401+ @pytest .mark .parametrize (
5402+ "make_input" ,
5403+ [
5404+ make_image_tensor ,
5405+ make_image_pil ,
5406+ make_image ,
5407+ make_video ,
5408+ pytest .param (
5409+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
5410+ ),
5411+ ],
5412+ )
54025413 def test_functional (self , make_input ):
54035414 check_functional (F .equalize , make_input ())
54045415
@@ -5409,33 +5420,71 @@ def test_functional(self, make_input):
54095420 (F ._color ._equalize_image_pil , PIL .Image .Image ),
54105421 (F .equalize_image , tv_tensors .Image ),
54115422 (F .equalize_video , tv_tensors .Video ),
5423+ pytest .param (
5424+ F ._color ._equalize_cvcuda ,
5425+ "cvcuda.Tensor" ,
5426+ marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" ),
5427+ ),
54125428 ],
54135429 )
54145430 def test_functional_signature (self , kernel , input_type ):
5431+ if input_type == "cvcuda.Tensor" :
5432+ input_type = _import_cvcuda ().Tensor
54155433 check_functional_kernel_signature_match (F .equalize , kernel = kernel , input_type = input_type )
54165434
54175435 @pytest .mark .parametrize (
54185436 "make_input" ,
5419- [make_image_tensor , make_image_pil , make_image , make_video ],
5437+ [
5438+ make_image_tensor ,
5439+ make_image_pil ,
5440+ make_image ,
5441+ make_video ,
5442+ pytest .param (
5443+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
5444+ ),
5445+ ],
54205446 )
54215447 def test_transform (self , make_input ):
54225448 check_transform (transforms .RandomEqualize (p = 1 ), make_input ())
54235449
54245450 @pytest .mark .parametrize (("low" , "high" ), [(0 , 64 ), (64 , 192 ), (192 , 256 ), (0 , 1 ), (127 , 128 ), (255 , 256 )])
5451+ @pytest .mark .parametrize (
5452+ "tensor_type" ,
5453+ [
5454+ torch .Tensor ,
5455+ pytest .param (
5456+ "cvcuda.Tensor" , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
5457+ ),
5458+ ],
5459+ )
54255460 @pytest .mark .parametrize ("fn" , [F .equalize , transform_cls_to_functional (transforms .RandomEqualize , p = 1 )])
5426- def test_image_correctness (self , low , high , fn ):
5461+ def test_image_correctness (self , low , high , tensor_type , fn ):
54275462 # We are not using the default `make_image` here since that uniformly samples the values over the whole value
54285463 # range. Since the whole point of F.equalize is to transform an arbitrary distribution of values into a uniform
54295464 # one over the full range, the information gain is low if we already provide something really close to the
54305465 # expected value.
5431- image = tv_tensors .Image (
5432- torch .testing .make_tensor ((3 , 117 , 253 ), dtype = torch .uint8 , device = "cpu" , low = low , high = high )
5433- )
5466+ shape = (3 , 117 , 253 )
5467+ if tensor_type == "cvcuda.Tensor" :
5468+ shape = (1 , * shape )
5469+ image = tv_tensors .Image (torch .testing .make_tensor (shape , dtype = torch .uint8 , device = "cpu" , low = low , high = high ))
5470+
5471+ if tensor_type == "cvcuda.Tensor" :
5472+ image = F .to_cvcuda_tensor (image )
54345473
54355474 actual = fn (image )
5475+
5476+ if tensor_type == "cvcuda.Tensor" :
5477+ actual = F .cvcuda_to_tensor (actual ).to (device = "cpu" )
5478+ actual = actual .squeeze (0 )
5479+ image = F .cvcuda_to_tensor (image )
5480+ image = image .squeeze (0 )
5481+
54365482 expected = F .to_image (F .equalize (F .to_pil_image (image )))
54375483
5438- assert_equal (actual , expected )
5484+ if tensor_type == "cvcuda.Tensor" :
5485+ torch .testing .assert_close (actual , expected , rtol = 1e-10 , atol = 1 )
5486+ else :
5487+ assert_equal (actual , expected )
54395488
54405489
54415490class TestUniformTemporalSubsample :
0 commit comments