@@ -2152,7 +2152,8 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand
21522152
21532153 actual = F .rotate (image , angle = angle , center = center , interpolation = interpolation , expand = expand , fill = fill )
21542154
2155- if make_input == make_image_cvcuda :
2155+ if make_input is make_image_cvcuda :
2156+ actual = cvcuda_to_pil_compatible_tensor (actual )
21562157 image = cvcuda_to_pil_compatible_tensor (image )
21572158
21582159 expected = F .to_image (
@@ -2162,7 +2163,7 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand
21622163 )
21632164
21642165 mae = (actual .float () - expected .float ()).abs ().mean ()
2165- if make_input == make_image_cvcuda :
2166+ if make_input is make_image_cvcuda :
21662167 # CV-CUDA nearest interpolation differs significantly from PIL, set much higher bound
21672168 assert mae < (122.5 ) if interpolation is transforms .InterpolationMode .NEAREST else 6 , f"MAE: { mae } "
21682169 else :
@@ -2202,16 +2203,14 @@ def test_transform_image_correctness(self, center, interpolation, expand, fill,
22022203
22032204 torch .manual_seed (seed )
22042205
2205- if make_input == make_image_cvcuda :
2206- actual = F .cvcuda_to_tensor (actual ).to (device = "cpu" )
2207- image = F .cvcuda_to_tensor (image )
2208- # drop the batch dimensions
2209- image = image .squeeze (0 )
2206+ if make_input is make_image_cvcuda :
2207+ actual = cvcuda_to_pil_compatible_tensor (actual )
2208+ image = cvcuda_to_pil_compatible_tensor (image )
22102209
22112210 expected = F .to_image (transform (F .to_pil_image (image )))
22122211
22132212 mae = (actual .float () - expected .float ()).abs ().mean ()
2214- if make_input == make_image_cvcuda :
2213+ if make_input is make_image_cvcuda :
22152214 # CV-CUDA nearest interpolation differs significantly from PIL, set much higher bound
22162215 assert mae < (122.5 ) if interpolation is transforms .InterpolationMode .NEAREST else 6 , f"MAE: { mae } "
22172216 else :
0 commit comments