@@ -4979,25 +4979,30 @@ def test_transform(self, make_input):
49794979 check_transform (transforms .CenterCrop (self .OUTPUT_SIZES [0 ]), make_input (self .INPUT_SIZE ))
49804980
49814981 @pytest .mark .parametrize ("output_size" , OUTPUT_SIZES )
4982+ @pytest .mark .parametrize (
4983+ "make_input" ,
4984+ [
4985+ make_image ,
4986+ pytest .param (
4987+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" )
4988+ ),
4989+ ],
4990+ )
49824991 @pytest .mark .parametrize ("fn" , [F .center_crop , transform_cls_to_functional (transforms .CenterCrop )])
4983- def test_image_correctness (self , output_size , fn ):
4984- image = make_image (self .INPUT_SIZE , dtype = torch .uint8 , device = "cpu" )
4992+ def test_image_correctness (self , output_size , make_input , fn ):
4993+ image = make_input (self .INPUT_SIZE , dtype = torch .uint8 , device = "cpu" )
49854994
49864995 actual = fn (image , output_size )
4987- expected = F .to_image (F .center_crop (F .to_pil_image (image ), output_size = output_size ))
49884996
4989- assert_equal (actual , expected )
4990-
4991- @pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" )
4992- @pytest .mark .parametrize ("output_size" , OUTPUT_SIZES )
4993- @pytest .mark .parametrize ("fn" , [F .center_crop , transform_cls_to_functional (transforms .CenterCrop )])
4994- def test_cvcuda_correctness (self , output_size , fn ):
4995- image = make_image_cvcuda (self .INPUT_SIZE , dtype = torch .uint8 , device = "cuda" )
4997+ if make_input == make_image_cvcuda :
4998+ actual = F .cvcuda_to_tensor (actual ).to (device = "cpu" )
4999+ actual = actual .squeeze (0 )
5000+ image = F .cvcuda_to_tensor (image ).to (device = "cpu" )
5001+ image = image .squeeze (0 )
49965002
4997- actual = fn (image , output_size )
4998- expected = F .center_crop (F .cvcuda_to_tensor (image ), output_size )
5003+ expected = F .to_image (F .center_crop (F .to_pil_image (image ), output_size = output_size ))
49995004
5000- assert_equal (F . cvcuda_to_tensor ( actual ) , expected )
5005+ assert_equal (actual , expected )
50015006
50025007 def _reference_center_crop_bounding_boxes (self , bounding_boxes , output_size ):
50035008 image_height , image_width = bounding_boxes .canvas_size
0 commit comments