@@ -2656,7 +2656,8 @@ def test_transform(self, make_input, input_dtype, output_dtype, device, scale, a
         output_dtype = {type(input): output_dtype}
         check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input, check_sample_input=not as_dict)

-    def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False):
+    @staticmethod
+    def reference_convert_dtype_image_tensor(image, dtype=torch.float, scale=False):
         input_dtype = image.dtype
         output_dtype = dtype
@@ -2807,6 +2808,91 @@ def test_uint16(self):
         assert_close(F.to_dtype(img_uint8, torch.float32, scale=True), img_float32, rtol=0, atol=1e-2)


+@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="cvcuda is not available")
+@needs_cuda
+class TestToDtypeCVCUDA:
+    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
+    @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("scale", (True, False))
+    def test_functional(self, input_dtype, output_dtype, device, scale):
+        check_functional(
+            F.to_dtype,
+            make_image_cvcuda(batch_dims=(1,), dtype=input_dtype, device=device),
+            dtype=output_dtype,
+            scale=scale,
+        )
+
+    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
+    @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("scale", (True, False))
+    @pytest.mark.parametrize("as_dict", (True, False))
+    def test_transform(self, input_dtype, output_dtype, device, scale, as_dict):
+        cvc_input = make_image_cvcuda(batch_dims=(1,), dtype=input_dtype, device=device)
+        if as_dict:
+            output_dtype = {type(cvc_input): output_dtype}
+        check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), cvc_input, check_sample_input=not as_dict)
+
+    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
+    @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("scale", (True, False))
+    def test_image_correctness(self, input_dtype, output_dtype, device, scale):
+        if input_dtype.is_floating_point and output_dtype == torch.int64:
+            pytest.xfail("float to int64 conversion is not supported")
+        if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda":
+            pytest.xfail("uint8 to uint16 conversion is not supported on cuda")
+        if input_dtype == torch.uint8 and output_dtype == torch.uint16 and scale:
+            pytest.xfail("uint8 to uint16 conversion with scale is not supported in F.to_dtype_image")
+
+        cvc_input = make_image_cvcuda(batch_dims=(1,), dtype=input_dtype, device=device)
+        torch_input = F.cvcuda_to_tensor(cvc_input)
+
+        out = F.to_dtype(cvc_input, dtype=output_dtype, scale=scale)
+        out = F.cvcuda_to_tensor(out)
+
+        expected = F.to_dtype(torch_input, dtype=output_dtype, scale=scale)
+
+        # There are some differences in dtype conversion between torchvision and cvcuda
+        # due to different rounding behavior when converting between types with different bit widths.
+        # Check whether we are converting to a type with more bits (without scaling).
+        in_bits = torch.iinfo(input_dtype).bits if not input_dtype.is_floating_point else None
+        out_bits = torch.iinfo(output_dtype).bits if not output_dtype.is_floating_point else None
+
+        if scale:
+            if input_dtype.is_floating_point and not output_dtype.is_floating_point:
+                # float -> int with scaling: allow for rounding differences
+                torch.testing.assert_close(out, expected, atol=1, rtol=0)
+            elif input_dtype == torch.uint16 and output_dtype == torch.uint8:
+                # uint16 -> uint8 with scaling: allow large differences
+                torch.testing.assert_close(out, expected, atol=255, rtol=0)
+            else:
+                torch.testing.assert_close(out, expected)
+        else:
+            if in_bits is not None and out_bits is not None and out_bits > in_bits:
+                # uint to larger uint without scaling: allow large differences due to bit expansion
+                if input_dtype == torch.uint8 and output_dtype == torch.uint16:
+                    torch.testing.assert_close(out, expected, atol=255, rtol=0)
+                else:
+                    torch.testing.assert_close(out, expected, atol=1, rtol=0)
+            elif not input_dtype.is_floating_point and not output_dtype.is_floating_point:
+                # uint to uint without scaling (same or smaller bits): allow for rounding
+                if input_dtype == torch.uint16 and output_dtype == torch.uint8:
+                    # uint16 -> uint8 can have large differences due to bit reduction
+                    torch.testing.assert_close(out, expected, atol=255, rtol=0)
+                else:
+                    torch.testing.assert_close(out, expected)
+            elif input_dtype.is_floating_point and not output_dtype.is_floating_point:
+                # float -> uint without scaling: allow for rounding differences
+                torch.testing.assert_close(out, expected, atol=1, rtol=0)
+            elif not input_dtype.is_floating_point and output_dtype.is_floating_point:
+                # uint -> float without scaling: allow for rounding differences
+                torch.testing.assert_close(out, expected, atol=1, rtol=0)
+            else:
+                torch.testing.assert_close(out, expected)
+
+
 class TestAdjustBrightness:
     _CORRECTNESS_BRIGHTNESS_FACTORS = [0.5, 0.0, 1.0, 5.0]
     _DEFAULT_BRIGHTNESS_FACTOR = _CORRECTNESS_BRIGHTNESS_FACTORS[0]
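A note on the atol=255 tolerances in test_image_correctness: two common conventions for a scaled uint8 -> uint16 conversion are left-shifting by the bit-width difference (multiply by 256) and multiplying by the ratio of maximum values, 65535 / 255 = 257. At full intensity these disagree by exactly 255, which is what the loose tolerance accounts for. The snippet below is a minimal standalone sketch in plain PyTorch illustrating the gap; it does not claim to reproduce either torchvision's or CV-CUDA's internals, and which library uses which convention is an assumption left unverified here.

    import torch

    # All 256 uint8 values, widened so the arithmetic below cannot overflow.
    u8 = torch.arange(256, dtype=torch.uint8).to(torch.int32)

    # Convention A: left-shift by the bit-width difference (multiply by 256).
    shifted = u8 << 8

    # Convention B: multiply by the max-value ratio, 65535 // 255 == 257.
    ratio = u8 * 257

    # The two conventions diverge by exactly the input value, so the worst
    # case over the uint8 range is 255 -- matching atol=255 in the tests.
    print(int((shifted - ratio).abs().max()))  # 255

The same reasoning motivates the atol=1 cases: float<->uint conversions can round half-way values differently between implementations, producing off-by-one results without either being wrong.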