@@ -5650,24 +5650,36 @@ def _reference_normalize_image(self, image, *, mean, std):
56505650
56515651 @pytest .mark .parametrize (("mean" , "std" ), MEANS_STDS )
56525652 @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .float32 , torch .float64 ])
5653+ @pytest .mark .parametrize (
5654+ "make_input" ,
5655+ [
5656+ make_image ,
5657+ pytest .param (
5658+ make_image_cvcuda , marks = pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" )
5659+ ),
5660+ ],
5661+ )
56535662 @pytest .mark .parametrize ("fn" , [F .normalize , transform_cls_to_functional (transforms .Normalize )])
5654- def test_correctness_image (self , mean , std , dtype , fn ):
5655- image = make_image (dtype = dtype )
5663+ def test_correctness_image (self , mean , std , dtype , make_input , fn ):
5664+ if make_input == make_image_cvcuda and dtype != torch .float32 :
5665+ pytest .skip ("CVCUDA only supports float32 for normalize" )
5666+
5667+ image = make_input (dtype = dtype )
56565668
56575669 actual = fn (image , mean = mean , std = std )
5658- expected = self ._reference_normalize_image (image , mean = mean , std = std )
56595670
5660- assert_equal (actual , expected )
5671+ if make_input == make_image_cvcuda :
5672+ image = F .cvcuda_to_tensor (image ).to (device = "cpu" )
5673+ image = image .squeeze (0 )
5674+ actual = F .cvcuda_to_tensor (actual ).to (device = "cpu" )
5675+ actual = actual .squeeze (0 )
56615676
5662- @pytest .mark .skipif (not CVCUDA_AVAILABLE , reason = "test requires CVCUDA" )
5663- @pytest .mark .parametrize (("mean" , "std" ), MEANS_STDS )
5664- @pytest .mark .parametrize ("fn" , [F .normalize , transform_cls_to_functional (transforms .Normalize )])
5665- def test_correctness_cvcuda (self , mean , std , fn ):
5666- image = make_image (batch_dims = (1 ,), dtype = torch .float32 , device = "cuda" )
5667- cvc_image = F .to_cvcuda_tensor (image )
5668- actual = F ._misc ._normalize_cvcuda (cvc_image , mean = mean , std = std )
5669- expected = fn (image , mean = mean , std = std )
5670- torch .testing .assert_close (F .cvcuda_to_tensor (actual ), expected , rtol = 1e-7 , atol = 1e-7 )
5677+ expected = self ._reference_normalize_image (image , mean = mean , std = std )
5678+
5679+ if make_input == make_image_cvcuda :
5680+ torch .testing .assert_close (actual , expected , rtol = 0 , atol = 1e-6 )
5681+ else :
5682+ assert_equal (actual , expected )
56715683
56725684
56735685class TestClampBoundingBoxes :
0 commit comments