diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 3ce603c3ed2..2c52c544fe1 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -21,6 +21,7 @@ import torchvision.transforms.v2 as transforms from common_utils import ( + assert_close, assert_equal, cache, cpu_and_cuda, @@ -41,7 +42,6 @@ ) from torch import nn -from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors @@ -60,9 +60,11 @@ ) -CVCUDA_AVAILABLE = _is_cvcuda_available() -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() +CV_CUDA_TEST = [ + pytest.mark.skipif(not _is_cvcuda_available(), reason="CVCUDA is not available"), + pytest.mark.needs_cuda, +] + # turns all warnings into errors for this module pytestmark = [pytest.mark.filterwarnings("error")] @@ -1240,10 +1242,7 @@ def test_kernel_video(self): make_image_tensor, make_image_pil, make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1259,11 +1258,7 @@ def test_functional(self, make_input): (F.horizontal_flip_image, torch.Tensor), (F._geometry._horizontal_flip_image_pil, PIL.Image.Image), (F.horizontal_flip_image, tv_tensors.Image), - pytest.param( - F._geometry._horizontal_flip_image_cvcuda, - None, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(F._geometry._horizontal_flip_image_cvcuda, None, marks=CV_CUDA_TEST), (F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes), (F.horizontal_flip_mask, tv_tensors.Mask), (F.horizontal_flip_video, tv_tensors.Video), @@ -1281,10 +1276,7 @@ def test_functional_signature(self, kernel, input_type): make_image_tensor, make_image_pil, make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1302,10 +1294,7 @@ def test_transform(self, make_input, device): "make_input", [ make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), ], ) def test_image_correctness(self, fn, make_input): @@ -1370,10 +1359,7 @@ def test_keypoints_correctness(self, fn): make_image_tensor, make_image_pil, make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1882,10 +1868,7 @@ def test_kernel_video(self): make_image_tensor, make_image_pil, make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1901,11 +1884,7 @@ def test_functional(self, make_input): (F.vertical_flip_image, torch.Tensor), (F._geometry._vertical_flip_image_pil, PIL.Image.Image), (F.vertical_flip_image, tv_tensors.Image), - pytest.param( - F._geometry._vertical_flip_image_cvcuda, - None, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(F._geometry._vertical_flip_image_cvcuda, None, marks=CV_CUDA_TEST), (F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes), (F.vertical_flip_mask, tv_tensors.Mask), (F.vertical_flip_video, tv_tensors.Video), @@ -1923,10 +1902,7 @@ def test_functional_signature(self, kernel, input_type): make_image_tensor, make_image_pil, make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1942,10 +1918,7 @@ def test_transform(self, make_input, device): "make_input", [ make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), ], ) def test_image_correctness(self, fn, make_input): @@ -2006,10 +1979,7 @@ def test_keypoints_correctness(self, fn): make_image_tensor, make_image_pil, make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), make_bounding_boxes, make_segmentation_mask, make_video, @@ -2627,7 +2597,32 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, sca scale=scale, ) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video]) + @pytest.mark.parametrize( + ("kernel", "input_type"), + [ + (F.to_dtype_image, torch.Tensor), + (F.to_dtype_video, tv_tensors.Video), + pytest.param( + F._misc._to_dtype_image_cvcuda, + None, + marks=CV_CUDA_TEST, + ), + ], + ) + def test_functional_signature(self, kernel, input_type): + if kernel is F._misc._to_dtype_image_cvcuda: + input_type = _import_cvcuda().Tensor + check_functional_kernel_signature_match(F.to_dtype, kernel=kernel, input_type=input_type) + + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image, + make_video, + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), + ], + ) @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -2642,7 +2637,14 @@ def test_functional(self, make_input, input_dtype, output_dtype, device, scale): @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image, make_bounding_boxes, make_segmentation_mask, make_video], + [ + make_image_tensor, + make_image, + make_bounding_boxes, + make_segmentation_mask, + make_video, + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), + ], ) @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) @@ -2688,25 +2690,69 @@ def fn(value): return torch.tensor(tree_map(fn, image.tolist())).to(dtype=output_dtype, device=image.device) + def _get_dtype_conversion_atol_cvcuda(self, input_dtype, output_dtype): + in_bits = torch.iinfo(input_dtype).bits if not input_dtype.is_floating_point else None + out_bits = torch.iinfo(output_dtype).bits if not output_dtype.is_floating_point else None + narrows_bits = in_bits is not None and out_bits is not None and out_bits < in_bits + + # int->int with narrowing bits, allow atol=1 for rounding diffs + if narrows_bits: + atol = 1 + # float->int check for same diff, rounding error on float + elif input_dtype.is_floating_point and not output_dtype.is_floating_point: + atol = 1 + # if generating a float value from an int, allow small rounding error + elif not input_dtype.is_floating_point and output_dtype.is_floating_point: + atol = 1e-7 + # all other cases, should be exact + # uint8 -> uint16 promotion would be here + else: + atol = 0 + + return atol + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16]) @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16]) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("scale", (True, False)) - def test_image_correctness(self, input_dtype, output_dtype, device, scale): + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), + ], + ) + @pytest.mark.parametrize("fn", [F.to_dtype, transform_cls_to_functional(transforms.ToDtype)]) + def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_input, fn): if input_dtype.is_floating_point and output_dtype == torch.int64: pytest.xfail("float to int64 conversion is not supported") if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda": pytest.xfail("uint8 to uint16 conversion is not supported on cuda") + if ( + input_dtype == torch.uint16 + and output_dtype == torch.uint8 + and not scale + and make_input is make_image_cvcuda + ): + pytest.xfail("uint16 to uint8 conversion without scale is not supported for CV-CUDA.") - input = make_image(dtype=input_dtype, device=device) + input = make_input(dtype=input_dtype, device=device) + out = fn(input, dtype=output_dtype, scale=scale) + + if make_input is make_image_cvcuda: + input = F.cvcuda_to_tensor(input) + out = F.cvcuda_to_tensor(out) - out = F.to_dtype(input, dtype=output_dtype, scale=scale) expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale) - if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale: - torch.testing.assert_close(out, expected, atol=1, rtol=0) - else: - torch.testing.assert_close(out, expected) + atol, rtol = None, None + if make_input is make_image_cvcuda: + atol = self._get_dtype_conversion_atol_cvcuda(input_dtype, output_dtype) + rtol = 0 + elif input_dtype.is_floating_point and not output_dtype.is_floating_point and scale: + atol, rtol = 1, 0 + + torch.testing.assert_close(out, expected, atol=atol, rtol=rtol) def was_scaled(self, inpt): # this assumes the target dtype is float @@ -6794,9 +6840,9 @@ def test_functional_error(self): F.pil_to_tensor(object()) -@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") -@needs_cuda class TestToCVCUDATensor: + pytestmark = CV_CUDA_TEST + @pytest.mark.parametrize("image_type", (torch.Tensor, tv_tensors.Image)) @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64]) @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -6813,7 +6859,7 @@ def test_functional_and_transform(self, image_type, dtype, device, color_space, assert is_pure_tensor(image) output = fn(image) - assert isinstance(output, cvcuda.Tensor) + assert isinstance(output, _import_cvcuda().Tensor) assert F.get_size(output) == F.get_size(image) assert output is not None @@ -6856,9 +6902,9 @@ def test_round_trip(self, dtype, device, color_space, batch_size): assert result_tensor.shape[0] == batch_size -@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") -@needs_cuda -class TestCVDUDAToTensor: +class TestCVCUDAToTensor: + pytestmark = CV_CUDA_TEST + @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64]) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 305149c87b1..26749c855a4 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -9,6 +9,7 @@ from torchvision import transforms as _transforms, tv_tensors from torchvision.transforms.v2 import functional as F, Transform +from torchvision.transforms.v2.functional._utils import _is_cvcuda_tensor from ._utils import ( _parse_labels_getter, @@ -267,7 +268,7 @@ class ToDtype(Transform): Default: ``False``. """ - _transformed_types = (torch.Tensor,) + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) def __init__( self, dtype: Union[torch.dtype, dict[Union[type, str], Optional[torch.dtype]]], scale: bool = False @@ -294,7 +295,11 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any: if isinstance(self.dtype, torch.dtype): # For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype # is a simple torch.dtype - if not is_pure_tensor(inpt) and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)): + if ( + not is_pure_tensor(inpt) + and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) + and not _is_cvcuda_tensor(inpt) + ): return inpt dtype: Optional[torch.dtype] = self.dtype @@ -311,7 +316,9 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any: 'e.g. dtype={tv_tensors.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.' ) - supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) + supports_scaling = ( + is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or _is_cvcuda_tensor(inpt) + ) if dtype is None: if self.scale and supports_scaling: warnings.warn( diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index daf263df046..6ae5466621c 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import Optional, TYPE_CHECKING import PIL.Image import torch @@ -13,7 +13,12 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] def normalize( @@ -340,6 +345,101 @@ def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: boo return inpt.to(dtype) +# cvcuda is only used if it is installed, so we can simply define empty mappings +_torch_to_cvcuda_dtypes: dict[torch.dtype, "cvcuda.Type"] = {} +_cvcuda_to_torch_dtypes: dict["cvcuda.Type", torch.dtype] = {} + + +def _to_dtype_image_cvcuda( + inpt: "cvcuda.Tensor", + dtype: torch.dtype = torch.float, + scale: bool = False, +) -> "cvcuda.Tensor": + """ + Convert the dtype of a CV-CUDA tensor, based on a torch.dtype. + + Args: + inpt: The CV-CUDA tensor to convert the dtype of. + dtype: The torch.dtype to convert the dtype to. + scale: Whether to scale the values to the new dtype. + There are four cases for the scaling setup: + 1. float -> float + 2. int -> int + 3. float -> int + 4. int -> float + If scale is True, the values will be scaled to the new dtype. + If scale is False, the values will not be scaled. + The scale values for float -> float are 1.0 and 0.0 respectively. + The scale values for int -> int are 2^(bit_diff) of the new dtype. + Where bit_diff is the difference in the number of bits of the new dtype and the input dtype. + The scale values for float -> int and int -> float are the maximum value of the new dtype. + + Returns: + out (cvcuda.Tensor): The CV-CUDA tensor with the converted dtype. + + """ + cvcuda = _import_cvcuda() + + if not _torch_to_cvcuda_dtypes: + _torch_to_cvcuda_dtypes[torch.uint8] = cvcuda.Type.U8 + _torch_to_cvcuda_dtypes[torch.uint16] = cvcuda.Type.U16 + _torch_to_cvcuda_dtypes[torch.uint32] = cvcuda.Type.U32 + _torch_to_cvcuda_dtypes[torch.uint64] = cvcuda.Type.U64 + _torch_to_cvcuda_dtypes[torch.int8] = cvcuda.Type.S8 + _torch_to_cvcuda_dtypes[torch.int16] = cvcuda.Type.S16 + _torch_to_cvcuda_dtypes[torch.int32] = cvcuda.Type.S32 + _torch_to_cvcuda_dtypes[torch.int64] = cvcuda.Type.S64 + _torch_to_cvcuda_dtypes[torch.float32] = cvcuda.Type.F32 + _torch_to_cvcuda_dtypes[torch.float64] = cvcuda.Type.F64 + + if not _cvcuda_to_torch_dtypes: + for k, v in _torch_to_cvcuda_dtypes.items(): + _cvcuda_to_torch_dtypes[v] = k + + dtype_in = _cvcuda_to_torch_dtypes.get(inpt.dtype) + cvc_dtype = _torch_to_cvcuda_dtypes.get(dtype) + if dtype_in is None or cvc_dtype is None: + raise ValueError(f"No torch or cvcuda dtype found for dtype {dtype} or {inpt.dtype}") + + # torchvision will overflow the values of uint16 when converting down to uint8 without scale + # example: 300 -> 255 (cvcuda) vs 300 mod 256 = 44 (torchvision) + # since it is not equivalent, raise an error for unsupported behavior + # the workaround could be using torch for dtype conversion directly via zero-copy + if dtype_in == torch.uint16 and dtype == torch.uint8 and not scale: + raise ValueError("uint16 to uint8 conversion without scale is not supported for CV-CUDA.") + + scale_val, offset = 1.0, 0.0 + if scale: + in_dtype_float = dtype_in.is_floating_point + out_dtype_float = dtype.is_floating_point + + if in_dtype_float and out_dtype_float: + scale_val, offset = 1.0, 0.0 + elif not in_dtype_float and not out_dtype_float: + in_bits = torch.iinfo(dtype_in).bits + out_bits = torch.iinfo(dtype).bits + scale_val = float(2 ** (out_bits - in_bits)) + offset = 0.0 + elif in_dtype_float and not out_dtype_float: + # Mirror the scaling factor which torchvision uses + eps = 1e-3 + max_val = float(_max_value(dtype)) + scale_val, offset = max_val + 1.0 - eps, 0.0 + else: + scale_val, offset = 1.0 / float(_max_value(dtype_in)), 0.0 + + return cvcuda.convertto( + inpt, + dtype=cvc_dtype, + scale=scale_val, + offset=offset, + ) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(to_dtype, _import_cvcuda().Tensor)(_to_dtype_image_cvcuda) + + def sanitize_bounding_boxes( bounding_boxes: torch.Tensor, format: Optional[tv_tensors.BoundingBoxFormat] = None,