@@ -5129,6 +5129,9 @@ def test_kernel_video(self):
51295129 make_segmentation_mask ,
51305130 make_video ,
51315131 make_keypoints ,
5132+ pytest .param (
5133+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
5134+ ),
51325135 ],
51335136 )
51345137 def test_functional (self , make_input ):
@@ -5144,9 +5147,16 @@ def test_functional(self, make_input):
51445147 (F .perspective_mask , tv_tensors .Mask ),
51455148 (F .perspective_video , tv_tensors .Video ),
51465149 (F .perspective_keypoints , tv_tensors .KeyPoints ),
5150+ pytest .param (
5151+ F ._geometry ._perspective_cvcuda ,
5152+ "cvcuda.Tensor" ,
5153+ marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" ),
5154+ ),
51475155 ],
51485156 )
51495157 def test_functional_signature (self , kernel , input_type ):
5158+ if input_type == "cvcuda.Tensor" :
5159+ input_type = _import_cvcuda ().Tensor
51505160 check_functional_kernel_signature_match (F .perspective , kernel = kernel , input_type = input_type )
51515161
51525162 @pytest .mark .parametrize ("distortion_scale" , [0.5 , 0.0 , 1.0 ])
@@ -5160,6 +5170,9 @@ def test_functional_signature(self, kernel, input_type):
51605170 make_segmentation_mask ,
51615171 make_video ,
51625172 make_keypoints ,
5173+ pytest .param (
5174+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
5175+ ),
51635176 ],
51645177 )
51655178 def test_transform (self , distortion_scale , make_input ):
@@ -5175,12 +5188,28 @@ def test_transform_error(self, distortion_scale):
51755188 "interpolation" , [transforms .InterpolationMode .NEAREST , transforms .InterpolationMode .BILINEAR ]
51765189 )
51775190 @pytest .mark .parametrize ("fill" , CORRECTNESS_FILLS )
5178- def test_image_functional_correctness (self , coefficients , interpolation , fill ):
5179- image = make_image (dtype = torch .uint8 , device = "cpu" )
5191+ @pytest .mark .parametrize (
5192+ "make_input" ,
5193+ [
5194+ make_image ,
5195+ pytest .param (
5196+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "CVCUDA not available" )
5197+ ),
5198+ ],
5199+ )
5200+ def test_image_functional_correctness (self , coefficients , interpolation , fill , make_input ):
5201+ image = make_input (dtype = torch .uint8 , device = "cpu" )
51805202
51815203 actual = F .perspective (
51825204 image , startpoints = None , endpoints = None , coefficients = coefficients , interpolation = interpolation , fill = fill
51835205 )
5206+ if make_input is make_image_cvcuda :
5207+ actual = F .cvcuda_to_tensor (actual ).to (device = "cpu" )
5208+ actual = actual .squeeze (0 )
5209+ # drop the batch dimension
5210+ image = F .cvcuda_to_tensor (image ).to (device = "cpu" )
5211+ image = image .squeeze (0 )
5212+
51845213 expected = F .to_image (
51855214 F .perspective (
51865215 F .to_pil_image (image ),
@@ -5192,13 +5221,20 @@ def test_image_functional_correctness(self, coefficients, interpolation, fill):
51925221 )
51935222 )
51945223
5195- if interpolation is transforms .InterpolationMode .BILINEAR :
5196- abs_diff = (actual .float () - expected .float ()).abs ()
5197- assert (abs_diff > 1 ).float ().mean () < 7e-2
5198- mae = abs_diff .mean ()
5199- assert mae < 3
5200- else :
5201- assert_equal (actual , expected )
5224+ if make_input is make_image :
5225+ if interpolation is transforms .InterpolationMode .BILINEAR :
5226+ abs_diff = (actual .float () - expected .float ()).abs ()
5227+ assert (abs_diff > 1 ).float ().mean () < 7e-2
5228+ mae = abs_diff .mean ()
5229+ assert mae < 3
5230+ else :
5231+ assert_equal (actual , expected )
5232+ else : # CV-CUDA
5233+ # just check that the shapes/dtypes are the same, cvcuda warp_perspective uses different algorithm
5234+ # visually the results are the same on real images,
5235+ # realistically, the diff is not visible to the human eye
5236+ tolerance = 255 if interpolation is transforms .InterpolationMode .NEAREST else 125
5237+ torch .testing .assert_close (actual , expected , rtol = 0 , atol = tolerance )
52025238
52035239 def _reference_perspective_bounding_boxes (self , bounding_boxes , * , startpoints , endpoints ):
52045240 format = bounding_boxes .format
0 commit comments