@@ -2204,7 +2204,8 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand
22042204
22052205 actual = F .rotate (image , angle = angle , center = center , interpolation = interpolation , expand = expand , fill = fill )
22062206
2207- if make_input == make_image_cvcuda :
2207+ if make_input is make_image_cvcuda :
2208+ actual = cvcuda_to_pil_compatible_tensor (actual )
22082209 image = cvcuda_to_pil_compatible_tensor (image )
22092210
22102211 expected = F .to_image (
@@ -2214,7 +2215,7 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand
22142215 )
22152216
22162217 mae = (actual .float () - expected .float ()).abs ().mean ()
2217- if make_input == make_image_cvcuda :
2218+ if make_input is make_image_cvcuda :
22182219 # CV-CUDA nearest interpolation differs significantly from PIL, set much higher bound
22192220 assert mae < (122.5 ) if interpolation is transforms .InterpolationMode .NEAREST else 6 , f"MAE: { mae } "
22202221 else :
@@ -2254,16 +2255,14 @@ def test_transform_image_correctness(self, center, interpolation, expand, fill,
22542255
22552256 torch .manual_seed (seed )
22562257
2257- if make_input == make_image_cvcuda :
2258- actual = F .cvcuda_to_tensor (actual ).to (device = "cpu" )
2259- image = F .cvcuda_to_tensor (image )
2260- # drop the batch dimensions
2261- image = image .squeeze (0 )
2258+ if make_input is make_image_cvcuda :
2259+ actual = cvcuda_to_pil_compatible_tensor (actual )
2260+ image = cvcuda_to_pil_compatible_tensor (image )
22622261
22632262 expected = F .to_image (transform (F .to_pil_image (image )))
22642263
22652264 mae = (actual .float () - expected .float ()).abs ().mean ()
2266- if make_input == make_image_cvcuda :
2265+ if make_input is make_image_cvcuda :
22672266 # CV-CUDA nearest interpolation differs significantly from PIL, set much higher bound
22682267 assert mae < (122.5 ) if interpolation is transforms .InterpolationMode .NEAREST else 6 , f"MAE: { mae } "
22692268 else :
0 commit comments