From 44db71c0772e5ef5758c38d0e4e8ad9995946c80 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:14:49 -0800 Subject: [PATCH 01/38] implement additional cvcuda infra for all branches to avoid duplicate setup --- torchvision/transforms/v2/_transform.py | 4 ++-- torchvision/transforms/v2/_utils.py | 3 ++- .../transforms/v2/functional/__init__.py | 2 +- .../transforms/v2/functional/_augment.py | 11 ++++++++++- .../transforms/v2/functional/_color.py | 12 +++++++++++- .../transforms/v2/functional/_geometry.py | 19 +++++++++++++++++-- torchvision/transforms/v2/functional/_misc.py | 11 +++++++++-- .../transforms/v2/functional/_utils.py | 16 ++++++++++++++++ 8 files changed, 68 insertions(+), 10 deletions(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index ac84fcb6c82..bec9ffcf714 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -11,7 +11,7 @@ from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once -from .functional._utils import _get_kernel +from .functional._utils import _get_kernel, is_cvcuda_tensor class Transform(nn.Module): @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) def __init__(self) -> None: super().__init__() diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index bb6051b4e61..765a772fe41 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -15,7 +15,7 @@ from torchvision._utils import sequence_to_str from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, + is_cvcuda_tensor, ), ) } diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 032a993b1f0..52181e4624b 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel # usort: skip +from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index a904d8d7cbd..7ce5bdc7b7e 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,4 +1,5 @@ import io +from typing import TYPE_CHECKING import PIL.Image @@ -8,7 +9,15 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def erase( diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index be254c0d63a..5be9c62902a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import PIL.Image import torch from torch.nn.functional import conv2d @@ -9,7 +11,15 @@ from ._misc import _num_value_bits, to_dtype_image from ._type_conversion import pil_to_tensor, to_pil_image -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 4fcb7fabe0d..c029488001c 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2,7 +2,7 @@ import numbers import warnings from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import PIL.Image import torch @@ -26,7 +26,22 @@ from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format -from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal +from ._utils import ( + _FillTypeJIT, + _get_kernel, + _import_cvcuda, + _is_cvcuda_available, + _register_five_ten_crop_kernel_internal, + _register_kernel_internal, +) + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index daf263df046..0fa05a2113c 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import Optional, TYPE_CHECKING import PIL.Image import torch @@ -13,7 +13,14 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def normalize( diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index ad1eddd258b..73fafaf7425 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -169,3 +169,19 @@ def _is_cvcuda_available(): return True except ImportError: return False + + +def is_cvcuda_tensor(inpt: Any) -> bool: + """ + Check if the input is a CVCUDA tensor. + + Args: + inpt: The input to check. + + Returns: + True if the input is a CV-CUDA tensor, False otherwise. + """ + if _is_cvcuda_available(): + cvcuda = _import_cvcuda() + return isinstance(inpt, cvcuda.Tensor) + return False From e3dd70022fa1c87aca7a9a98068b6e13e802a375 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:26:19 -0800 Subject: [PATCH 02/38] update make_image_cvcuda to have default batch dim --- test/common_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 8c3c9dd58a8..e7bae60c41b 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -400,8 +400,9 @@ def make_image_pil(*args, **kwargs): return to_pil_image(make_image(*args, **kwargs)) -def make_image_cvcuda(*args, **kwargs): - return to_cvcuda_tensor(make_image(*args, **kwargs)) +def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): + # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4) + return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"): From c035df1c6eaebcad25604f8c298a7d9eaf86864b Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 18:16:27 -0800 Subject: [PATCH 03/38] add stanardized setup to main for easier updating of PRs and branches --- test/common_utils.py | 21 ++++++++++++++-- test/test_transforms_v2.py | 2 +- torchvision/transforms/v2/_utils.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 24 +++++++++++++++++-- 4 files changed, 43 insertions(+), 6 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index e7bae60c41b..3b889e93d2e 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -20,13 +20,15 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available from torchvision.utils import _Image_fromarray IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"]) IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" +CVCUDA_AVAILABLE = _is_cvcuda_available() CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" MPS_NOT_AVAILABLE_MSG = "MPS device not available" OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda." @@ -275,6 +277,17 @@ def combinations_grid(**kwargs): return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] +def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor: + tensor = cvcuda_to_tensor(tensor) + if tensor.ndim != 4: + raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.") + if tensor.shape[0] != 1: + raise ValueError( + f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}." + ) + return tensor.squeeze(0).cpu() + + class ImagePair(TensorLikePair): def __init__( self, @@ -287,6 +300,11 @@ def __init__( if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]): actual, expected = (to_image(input) for input in [actual, expected]) + # handle check for CV-CUDA Tensors + if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor): + # Use the PIL compatible tensor, so we can always compare with PIL.Image.Image + actual = cvcuda_to_pil_compatible_tensor(actual) + super().__init__(actual, expected, **other_parameters) self.mae = mae @@ -401,7 +419,6 @@ def make_image_pil(*args, **kwargs): def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): - # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4) return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 670a9d00ffb..7eba65550da 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -21,6 +21,7 @@ import torchvision.transforms.v2 as transforms from common_utils import ( + assert_close, assert_equal, cache, cpu_and_cuda, @@ -41,7 +42,6 @@ ) from torch import nn -from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 765a772fe41..3fc33ce5964 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 6b8f19f12f4..ee562cb2aee 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) +def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]: + # CV-CUDA tensor is always in NHWC layout + # get_dimensions is CHW + return [image.shape[3], image.shape[1], image.shape[2]] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda) + + def get_num_channels(inpt: torch.Tensor) -> int: if torch.jit.is_scripting(): return get_num_channels_image(inpt) @@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels +def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int: + # CV-CUDA tensor is always in NHWC layout + # get_num_channels is C + return image.shape[3] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda) + + def get_size(inpt: torch.Tensor) -> list[int]: if torch.jit.is_scripting(): return get_size_image(inpt) @@ -114,7 +134,7 @@ def _get_size_image_pil(image: PIL.Image.Image) -> list[int]: return [height, width] -def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: +def _get_size_cvcuda(image: "cvcuda.Tensor") -> list[int]: """Get size of `cvcuda.Tensor` with NHWC layout.""" hw = list(image.shape[-3:-1]) ndims = len(hw) @@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: if CVCUDA_AVAILABLE: - _get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda) + _register_kernel_internal(get_size, cvcuda.Tensor)(_get_size_cvcuda) @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False) From 98d7dfb2059eaf2c10c3f549ea45f1d27875134c Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 18:25:09 -0800 Subject: [PATCH 04/38] update is_cvcuda_tensor --- torchvision/transforms/v2/functional/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 73fafaf7425..44b2edeaf2d 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -181,7 +181,8 @@ def is_cvcuda_tensor(inpt: Any) -> bool: Returns: True if the input is a CV-CUDA tensor, False otherwise. """ - if _is_cvcuda_available(): + try: cvcuda = _import_cvcuda() return isinstance(inpt, cvcuda.Tensor) - return False + except ImportError: + return False From ddc116d13febdae1d53507bcde9f103a4c14eba7 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:37:03 -0800 Subject: [PATCH 05/38] add cvcuda to pil compatible to transforms by default --- test/test_transforms_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 7eba65550da..87166477669 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -25,6 +25,7 @@ assert_equal, cache, cpu_and_cuda, + cvcuda_to_pil_compatible_tensor, freeze_rng_state, ignore_jit_no_profile_information_warning, make_bounding_boxes, From e51dc7eabd254261347245f4492892fd0944aae5 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:46:23 -0800 Subject: [PATCH 06/38] remove cvcuda from transform class --- torchvision/transforms/v2/_transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index bec9ffcf714..ac84fcb6c82 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -11,7 +11,7 @@ from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once -from .functional._utils import _get_kernel, is_cvcuda_tensor +from .functional._utils import _get_kernel class Transform(nn.Module): @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) def __init__(self) -> None: super().__init__() From 4939355a2c7421eeba95d7f155fe7953066aec6d Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 11:07:08 -0800 Subject: [PATCH 07/38] resolve more formatting naming --- torchvision/transforms/v2/functional/__init__.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 52181e4624b..032a993b1f0 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip +from ._utils import is_pure_tensor, register_kernel # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index e8630f788ca..af03ad018d4 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,14 +51,14 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) -def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]: +def get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: # CV-CUDA tensor is always in NHWC layout # get_dimensions is CHW return [image.shape[3], image.shape[1], image.shape[2]] if CVCUDA_AVAILABLE: - _register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda) + _register_kernel_internal(get_dimensions, cvcuda.Tensor)(get_dimensions_image_cvcuda) def get_num_channels(inpt: torch.Tensor) -> int: @@ -97,14 +97,14 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels -def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int: +def get_num_channels_image_cvcuda(image: "cvcuda.Tensor") -> int: # CV-CUDA tensor is always in NHWC layout # get_num_channels is C return image.shape[3] if CVCUDA_AVAILABLE: - _register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda) + _register_kernel_internal(get_num_channels, cvcuda.Tensor)(get_num_channels_image_cvcuda) def get_size(inpt: torch.Tensor) -> list[int]: From ec76196d23adfd7202e9f20f430f78f94cd930d2 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 18 Nov 2025 11:19:47 -0800 Subject: [PATCH 08/38] initial draft of to_dtype_cvcuda --- test/test_transforms_v2.py | 88 ++++++++++++++++++- .../transforms/v2/functional/__init__.py | 1 + torchvision/transforms/v2/functional/_misc.py | 65 ++++++++++++++ 3 files changed, 153 insertions(+), 1 deletion(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f767e211125..9c956535a82 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2656,7 +2656,8 @@ def test_transform(self, make_input, input_dtype, output_dtype, device, scale, a output_dtype = {type(input): output_dtype} check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input, check_sample_input=not as_dict) - def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False): + @staticmethod + def reference_convert_dtype_image_tensor(image, dtype=torch.float, scale=False): input_dtype = image.dtype output_dtype = dtype @@ -2807,6 +2808,91 @@ def test_uint16(self): assert_close(F.to_dtype(img_uint8, torch.float32, scale=True), img_float32, rtol=0, atol=1e-2) +@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="cvcuda is not available") +@needs_cuda +class TestToDtypeCVCUDA: + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("scale", (True, False)) + def test_functional(self, input_dtype, output_dtype, device, scale): + check_functional( + F.to_dtype, + make_image_cvcuda(batch_dims=(1,), dtype=input_dtype, device=device), + dtype=output_dtype, + scale=scale, + ) + + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("scale", (True, False)) + @pytest.mark.parametrize("as_dict", (True, False)) + def test_transform(self, input_dtype, output_dtype, device, scale, as_dict): + cvc_input = make_image_cvcuda(batch_dims=(1,), dtype=input_dtype, device=device) + if as_dict: + output_dtype = {type(cvc_input): output_dtype} + check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), cvc_input, check_sample_input=not as_dict) + + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16]) + @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("scale", (True, False)) + def test_image_correctness(self, input_dtype, output_dtype, device, scale): + if input_dtype.is_floating_point and output_dtype == torch.int64: + pytest.xfail("float to int64 conversion is not supported") + if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda": + pytest.xfail("uint8 to uint16 conversion is not supported on cuda") + if input_dtype == torch.uint8 and output_dtype == torch.uint16 and scale: + pytest.xfail("uint8 to uint16 conversion with scale is not supported in F.to_dtype_image") + + cvc_input = make_image_cvcuda(batch_dims=(1,), dtype=input_dtype, device=device) + torch_input = F.cvcuda_to_tensor(cvc_input) + + out = F.to_dtype(cvc_input, dtype=output_dtype, scale=scale) + out = F.cvcuda_to_tensor(out) + + expected = F.to_dtype(torch_input, dtype=output_dtype, scale=scale) + + # there are some differences in dtype conversion between torchvision and cvcuda + # due to different rounding behavior when converting between types with different bit widths + # Check if we're converting to a type with more bits (without scaling) + in_bits = torch.iinfo(input_dtype).bits if not input_dtype.is_floating_point else None + out_bits = torch.iinfo(output_dtype).bits if not output_dtype.is_floating_point else None + + if scale: + if input_dtype.is_floating_point and not output_dtype.is_floating_point: + # float -> int with scaling: allow for rounding differences + torch.testing.assert_close(out, expected, atol=1, rtol=0) + elif input_dtype == torch.uint16 and output_dtype == torch.uint8: + # uint16 -> uint8 with scaling: allow large differences + torch.testing.assert_close(out, expected, atol=255, rtol=0) + else: + torch.testing.assert_close(out, expected) + else: + if in_bits is not None and out_bits is not None and out_bits > in_bits: + # uint to larger uint without scaling: allow large differences due to bit expansion + if input_dtype == torch.uint8 and output_dtype == torch.uint16: + torch.testing.assert_close(out, expected, atol=255, rtol=0) + else: + torch.testing.assert_close(out, expected, atol=1, rtol=0) + elif not input_dtype.is_floating_point and not output_dtype.is_floating_point: + # uint to uint without scaling (same or smaller bits): allow for rounding + if input_dtype == torch.uint16 and output_dtype == torch.uint8: + # uint16 -> uint8 can have large differences due to bit reduction + torch.testing.assert_close(out, expected, atol=255, rtol=0) + else: + torch.testing.assert_close(out, expected) + elif input_dtype.is_floating_point and not output_dtype.is_floating_point: + # float -> uint without scaling: allow for rounding differences + torch.testing.assert_close(out, expected, atol=1, rtol=0) + elif not input_dtype.is_floating_point and output_dtype.is_floating_point: + # uint -> float without scaling: allow for rounding differences + torch.testing.assert_close(out, expected, atol=1, rtol=0) + else: + torch.testing.assert_close(out, expected) + + class TestAdjustBrightness: _CORRECTNESS_BRIGHTNESS_FACTORS = [0.5, 0.0, 1.0, 5.0] _DEFAULT_BRIGHTNESS_FACTOR = _CORRECTNESS_BRIGHTNESS_FACTORS[0] diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 032a993b1f0..58f2d732dc3 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -158,6 +158,7 @@ sanitize_bounding_boxes, sanitize_keypoints, to_dtype, + to_dtype_cvcuda, to_dtype_image, to_dtype_video, ) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 0fa05a2113c..5ee924de2e7 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -347,6 +347,71 @@ def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: boo return inpt.to(dtype) +# cvcuda is only used if it is installed, so we can simply define empty mappings +_torch_to_cvcuda_dtypes = {} +_cvcuda_to_torch_dtypes = {} +if CVCUDA_AVAILABLE: + # put the entire conversion set here + # only a subset are used for torchvision + _torch_to_cvcuda_dtypes = { + torch.uint8: cvcuda.Type.U8, + torch.uint16: cvcuda.Type.U16, + torch.uint32: cvcuda.Type.U32, + torch.uint64: cvcuda.Type.U64, + torch.int8: cvcuda.Type.S8, + torch.int16: cvcuda.Type.S16, + torch.int32: cvcuda.Type.S32, + torch.int64: cvcuda.Type.S64, + torch.float32: cvcuda.Type.F32, + torch.float64: cvcuda.Type.F64, + torch.complex64: cvcuda.Type.C64, + torch.complex128: cvcuda.Type.C128, + } + # create reverse mapping + _cvcuda_to_torch_dtypes = {v: k for k, v in _torch_to_cvcuda_dtypes.items()} + + +def to_dtype_cvcuda( + inpt: "cvcuda.Tensor", + dtype: torch.dtype, + scale: bool = False, +) -> "cvcuda.Tensor": + dtype_in = _cvcuda_to_torch_dtypes[inpt.dtype] + cvc_dtype = _torch_to_cvcuda_dtypes[dtype] + + if not scale: + return cvcuda.convertto(inpt, dtype=cvc_dtype) + + scale_val, offset = 1.0, 0.0 + in_dtype_float = dtype_in.is_floating_point + out_dtype_float = dtype.is_floating_point + + # four cases for the scaling setup + # 1. float -> float + # 2. int -> int + # 3. float -> int + # 4. int -> float + if in_dtype_float and out_dtype_float: + scale_val, offset = 1.0, 0.0 + elif not in_dtype_float and not out_dtype_float: + scale_val, offset = 1.0, 0.0 + elif in_dtype_float and not out_dtype_float: + scale_val, offset = float(_max_value(dtype)), 0.0 + else: + scale_val, offset = 1.0 / float(_max_value(dtype_in)), 0.0 + + return cvcuda.convertto( + inpt, + dtype=cvc_dtype, + scale=scale_val, + offset=offset, + ) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(to_dtype, cvcuda.Tensor)(to_dtype_cvcuda) + + def sanitize_bounding_boxes( bounding_boxes: torch.Tensor, format: Optional[tv_tensors.BoundingBoxFormat] = None, From bd823cff637fc576b30258473049b3c9700cd11c Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 20 Nov 2025 12:38:26 -0800 Subject: [PATCH 09/38] fix: to_dtype_cvcuda conventions --- test/test_transforms_v2.py | 53 ++++++++----------- .../transforms/v2/functional/__init__.py | 1 - torchvision/transforms/v2/functional/_misc.py | 6 ++- 3 files changed, 27 insertions(+), 33 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 9c956535a82..3603829c68c 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2628,7 +2628,17 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, sca scale=scale, ) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video]) + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], + ) @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -2643,7 +2653,16 @@ def test_functional(self, make_input, input_dtype, output_dtype, device, scale): @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image, make_bounding_boxes, make_segmentation_mask, make_video], + [ + make_image_tensor, + make_image, + make_bounding_boxes, + make_segmentation_mask, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], ) @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) @@ -2807,38 +2826,12 @@ def test_uint16(self): assert_equal(F.to_dtype(img_float32, torch.uint8, scale=True), img_uint8) assert_close(F.to_dtype(img_uint8, torch.float32, scale=True), img_float32, rtol=0, atol=1e-2) - -@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="cvcuda is not available") -@needs_cuda -class TestToDtypeCVCUDA: - @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) - @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) - @pytest.mark.parametrize("device", cpu_and_cuda()) - @pytest.mark.parametrize("scale", (True, False)) - def test_functional(self, input_dtype, output_dtype, device, scale): - check_functional( - F.to_dtype, - make_image_cvcuda(batch_dims=(1,), dtype=input_dtype, device=device), - dtype=output_dtype, - scale=scale, - ) - - @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) - @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8]) - @pytest.mark.parametrize("device", cpu_and_cuda()) - @pytest.mark.parametrize("scale", (True, False)) - @pytest.mark.parametrize("as_dict", (True, False)) - def test_transform(self, input_dtype, output_dtype, device, scale, as_dict): - cvc_input = make_image_cvcuda(batch_dims=(1,), dtype=input_dtype, device=device) - if as_dict: - output_dtype = {type(cvc_input): output_dtype} - check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), cvc_input, check_sample_input=not as_dict) - + @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16]) @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16]) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("scale", (True, False)) - def test_image_correctness(self, input_dtype, output_dtype, device, scale): + def test_cvcuda_parity(self, input_dtype, output_dtype, device, scale): if input_dtype.is_floating_point and output_dtype == torch.int64: pytest.xfail("float to int64 conversion is not supported") if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda": diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 58f2d732dc3..032a993b1f0 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -158,7 +158,6 @@ sanitize_bounding_boxes, sanitize_keypoints, to_dtype, - to_dtype_cvcuda, to_dtype_image, to_dtype_video, ) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 5ee924de2e7..573adfeefbc 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -371,11 +371,13 @@ def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: boo _cvcuda_to_torch_dtypes = {v: k for k, v in _torch_to_cvcuda_dtypes.items()} -def to_dtype_cvcuda( +def _to_dtype_cvcuda( inpt: "cvcuda.Tensor", dtype: torch.dtype, scale: bool = False, ) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + dtype_in = _cvcuda_to_torch_dtypes[inpt.dtype] cvc_dtype = _torch_to_cvcuda_dtypes[dtype] @@ -409,7 +411,7 @@ def to_dtype_cvcuda( if CVCUDA_AVAILABLE: - _register_kernel_internal(to_dtype, cvcuda.Tensor)(to_dtype_cvcuda) + _to_dtype_cvcuda_registered = _register_kernel_internal(to_dtype, _import_cvcuda().Tensor)(_to_dtype_cvcuda) def sanitize_bounding_boxes( From f7aa94a891a1e38da4ec5c3db6c25095fbf87bde Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 24 Nov 2025 08:19:39 -0800 Subject: [PATCH 10/38] remove staticmethod from reference todtype --- test/test_transforms_v2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 3603829c68c..099d89cef84 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2675,8 +2675,7 @@ def test_transform(self, make_input, input_dtype, output_dtype, device, scale, a output_dtype = {type(input): output_dtype} check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input, check_sample_input=not as_dict) - @staticmethod - def reference_convert_dtype_image_tensor(image, dtype=torch.float, scale=False): + def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False): input_dtype = image.dtype output_dtype = dtype From b21d9f00911b6f56c2371fd77c2018ae47680c87 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 24 Nov 2025 08:30:21 -0800 Subject: [PATCH 11/38] add docstring for explain scaling setup, combine correctness checks --- test/test_transforms_v2.py | 128 +++++++++--------- torchvision/transforms/v2/functional/_misc.py | 21 +++ 2 files changed, 82 insertions(+), 67 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 099d89cef84..88d5dbe8431 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2712,21 +2712,74 @@ def fn(value): @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16]) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("scale", (True, False)) - def test_image_correctness(self, input_dtype, output_dtype, device, scale): + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], + ) + def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_input): if input_dtype.is_floating_point and output_dtype == torch.int64: pytest.xfail("float to int64 conversion is not supported") if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda": pytest.xfail("uint8 to uint16 conversion is not supported on cuda") + if input_dtype == torch.uint8 and output_dtype == torch.uint16 and scale and make_input == make_image_cvcuda: + pytest.xfail("uint8 to uint16 conversion with scale is not supported in F._misc._to_dtype_cvcuda") - input = make_image(dtype=input_dtype, device=device) - + input = make_input(dtype=input_dtype, device=device) out = F.to_dtype(input, dtype=output_dtype, scale=scale) - expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale) - if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale: - torch.testing.assert_close(out, expected, atol=1, rtol=0) - else: - torch.testing.assert_close(out, expected) + if isinstance(input, torch.Tensor): + expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale) + if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale: + torch.testing.assert_close(out, expected, atol=1, rtol=0) + else: + torch.testing.assert_close(out, expected) + else: # cvcuda.Tensor + expected = self.reference_convert_dtype_image_tensor( + F.cvcuda_to_tensor(input), dtype=output_dtype, scale=scale + ) + out = F.cvcuda_to_tensor(out) + # there are some differences in dtype conversion between torchvision and cvcuda + # due to different rounding behavior when converting between types with different bit widths + # Check if we're converting to a type with more bits (without scaling) + in_bits = torch.iinfo(input_dtype).bits if not input_dtype.is_floating_point else None + out_bits = torch.iinfo(output_dtype).bits if not output_dtype.is_floating_point else None + + if scale: + if input_dtype.is_floating_point and not output_dtype.is_floating_point: + # float -> int with scaling: allow for rounding differences + torch.testing.assert_close(out, expected, atol=1, rtol=0) + elif input_dtype == torch.uint16 and output_dtype == torch.uint8: + # uint16 -> uint8 with scaling: allow large differences + torch.testing.assert_close(out, expected, atol=255, rtol=0) + else: + torch.testing.assert_close(out, expected) + else: + if in_bits is not None and out_bits is not None and out_bits > in_bits: + # uint to larger uint without scaling: allow large differences due to bit expansion + if input_dtype == torch.uint8 and output_dtype == torch.uint16: + torch.testing.assert_close(out, expected, atol=255, rtol=0) + else: + torch.testing.assert_close(out, expected, atol=1, rtol=0) + elif not input_dtype.is_floating_point and not output_dtype.is_floating_point: + # uint to uint without scaling (same or smaller bits): allow for rounding + if input_dtype == torch.uint16 and output_dtype == torch.uint8: + # uint16 -> uint8 can have large differences due to bit reduction + torch.testing.assert_close(out, expected, atol=255, rtol=0) + else: + torch.testing.assert_close(out, expected) + elif input_dtype.is_floating_point and not output_dtype.is_floating_point: + # float -> uint without scaling: allow for rounding differences + torch.testing.assert_close(out, expected, atol=1, rtol=0) + elif not input_dtype.is_floating_point and output_dtype.is_floating_point: + # uint -> float without scaling: allow for rounding differences + torch.testing.assert_close(out, expected, atol=1, rtol=0) + else: + torch.testing.assert_close(out, expected) def was_scaled(self, inpt): # this assumes the target dtype is float @@ -2825,65 +2878,6 @@ def test_uint16(self): assert_equal(F.to_dtype(img_float32, torch.uint8, scale=True), img_uint8) assert_close(F.to_dtype(img_uint8, torch.float32, scale=True), img_float32, rtol=0, atol=1e-2) - @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") - @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16]) - @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16]) - @pytest.mark.parametrize("device", cpu_and_cuda()) - @pytest.mark.parametrize("scale", (True, False)) - def test_cvcuda_parity(self, input_dtype, output_dtype, device, scale): - if input_dtype.is_floating_point and output_dtype == torch.int64: - pytest.xfail("float to int64 conversion is not supported") - if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda": - pytest.xfail("uint8 to uint16 conversion is not supported on cuda") - if input_dtype == torch.uint8 and output_dtype == torch.uint16 and scale: - pytest.xfail("uint8 to uint16 conversion with scale is not supported in F.to_dtype_image") - - cvc_input = make_image_cvcuda(batch_dims=(1,), dtype=input_dtype, device=device) - torch_input = F.cvcuda_to_tensor(cvc_input) - - out = F.to_dtype(cvc_input, dtype=output_dtype, scale=scale) - out = F.cvcuda_to_tensor(out) - - expected = F.to_dtype(torch_input, dtype=output_dtype, scale=scale) - - # there are some differences in dtype conversion between torchvision and cvcuda - # due to different rounding behavior when converting between types with different bit widths - # Check if we're converting to a type with more bits (without scaling) - in_bits = torch.iinfo(input_dtype).bits if not input_dtype.is_floating_point else None - out_bits = torch.iinfo(output_dtype).bits if not output_dtype.is_floating_point else None - - if scale: - if input_dtype.is_floating_point and not output_dtype.is_floating_point: - # float -> int with scaling: allow for rounding differences - torch.testing.assert_close(out, expected, atol=1, rtol=0) - elif input_dtype == torch.uint16 and output_dtype == torch.uint8: - # uint16 -> uint8 with scaling: allow large differences - torch.testing.assert_close(out, expected, atol=255, rtol=0) - else: - torch.testing.assert_close(out, expected) - else: - if in_bits is not None and out_bits is not None and out_bits > in_bits: - # uint to larger uint without scaling: allow large differences due to bit expansion - if input_dtype == torch.uint8 and output_dtype == torch.uint16: - torch.testing.assert_close(out, expected, atol=255, rtol=0) - else: - torch.testing.assert_close(out, expected, atol=1, rtol=0) - elif not input_dtype.is_floating_point and not output_dtype.is_floating_point: - # uint to uint without scaling (same or smaller bits): allow for rounding - if input_dtype == torch.uint16 and output_dtype == torch.uint8: - # uint16 -> uint8 can have large differences due to bit reduction - torch.testing.assert_close(out, expected, atol=255, rtol=0) - else: - torch.testing.assert_close(out, expected) - elif input_dtype.is_floating_point and not output_dtype.is_floating_point: - # float -> uint without scaling: allow for rounding differences - torch.testing.assert_close(out, expected, atol=1, rtol=0) - elif not input_dtype.is_floating_point and output_dtype.is_floating_point: - # uint -> float without scaling: allow for rounding differences - torch.testing.assert_close(out, expected, atol=1, rtol=0) - else: - torch.testing.assert_close(out, expected) - class TestAdjustBrightness: _CORRECTNESS_BRIGHTNESS_FACTORS = [0.5, 0.0, 1.0, 5.0] diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 573adfeefbc..21a87cab14b 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -376,6 +376,27 @@ def _to_dtype_cvcuda( dtype: torch.dtype, scale: bool = False, ) -> "cvcuda.Tensor": + """ + Convert the dtype of a CV-CUDA tensor, based on a torch.dtype. + + Args: + inpt: The CV-CUDA tensor to convert the dtype of. + dtype: The torch.dtype to convert the dtype to. + scale: Whether to scale the values to the new dtype. + There are four cases for the scaling setup: + 1. float -> float + 2. int -> int + 3. float -> int + 4. int -> float + If scale is True, the values will be scaled to the new dtype. + If scale is False, the values will not be scaled. + The scale values for float -> float and int -> int are 1.0 and 0.0 respectively. + The scale values for float -> int and int -> float are the maximum value of the new dtype. + + Returns: + out (cvcuda.Tensor): The CV-CUDA tensor with the converted dtype. + + """ cvcuda = _import_cvcuda() dtype_in = _cvcuda_to_torch_dtypes[inpt.dtype] From 973e058f590d87e62a158d1676e7bc4414933835 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 24 Nov 2025 08:34:18 -0800 Subject: [PATCH 12/38] resolve more review comments --- torchvision/transforms/v2/functional/_misc.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 21a87cab14b..4d6c10e6b66 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -364,8 +364,6 @@ def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: boo torch.int64: cvcuda.Type.S64, torch.float32: cvcuda.Type.F32, torch.float64: cvcuda.Type.F64, - torch.complex64: cvcuda.Type.C64, - torch.complex128: cvcuda.Type.C128, } # create reverse mapping _cvcuda_to_torch_dtypes = {v: k for k, v in _torch_to_cvcuda_dtypes.items()} @@ -414,9 +412,7 @@ def _to_dtype_cvcuda( # 2. int -> int # 3. float -> int # 4. int -> float - if in_dtype_float and out_dtype_float: - scale_val, offset = 1.0, 0.0 - elif not in_dtype_float and not out_dtype_float: + if in_dtype_float == out_dtype_float: scale_val, offset = 1.0, 0.0 elif in_dtype_float and not out_dtype_float: scale_val, offset = float(_max_value(dtype)), 0.0 @@ -432,7 +428,7 @@ def _to_dtype_cvcuda( if CVCUDA_AVAILABLE: - _to_dtype_cvcuda_registered = _register_kernel_internal(to_dtype, _import_cvcuda().Tensor)(_to_dtype_cvcuda) + _register_kernel_internal(to_dtype, _import_cvcuda().Tensor)(_to_dtype_cvcuda) def sanitize_bounding_boxes( From d871331af3867e42a38aab6188e6a9357dcbec65 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Wed, 26 Nov 2025 09:49:31 -0800 Subject: [PATCH 13/38] simplify todtype testing --- test/test_transforms_v2.py | 72 ++++++++----------- torchvision/transforms/v2/functional/_misc.py | 27 +++---- 2 files changed, 38 insertions(+), 61 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 88d5dbe8431..f94dc7c6008 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2708,6 +2708,28 @@ def fn(value): return torch.tensor(tree_map(fn, image.tolist())).to(dtype=output_dtype, device=image.device) + def _get_dtype_conversion_atol(self, input_dtype, output_dtype, scale): + is_uint16_to_uint8 = input_dtype == torch.uint16 and output_dtype == torch.uint8 + is_uint8_to_uint16 = input_dtype == torch.uint8 and output_dtype == torch.uint16 + changes_type_class = output_dtype.is_floating_point != input_dtype.is_floating_point + + in_bits = torch.iinfo(input_dtype).bits if not input_dtype.is_floating_point else None + out_bits = torch.iinfo(output_dtype).bits if not output_dtype.is_floating_point else None + expands_bits = in_bits is not None and out_bits is not None and out_bits > in_bits + + if is_uint16_to_uint8: + atol = 255 + elif is_uint8_to_uint16 and not scale: + atol = 255 + elif expands_bits and not scale: + atol = 1 + elif changes_type_class: + atol = 1 + else: + atol = 0 + + return atol + @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16]) @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16]) @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -2732,54 +2754,16 @@ def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_ input = make_input(dtype=input_dtype, device=device) out = F.to_dtype(input, dtype=output_dtype, scale=scale) - if isinstance(input, torch.Tensor): - expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale) - if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale: - torch.testing.assert_close(out, expected, atol=1, rtol=0) - else: - torch.testing.assert_close(out, expected) - else: # cvcuda.Tensor + if make_input == make_image_cvcuda: expected = self.reference_convert_dtype_image_tensor( F.cvcuda_to_tensor(input), dtype=output_dtype, scale=scale ) out = F.cvcuda_to_tensor(out) - # there are some differences in dtype conversion between torchvision and cvcuda - # due to different rounding behavior when converting between types with different bit widths - # Check if we're converting to a type with more bits (without scaling) - in_bits = torch.iinfo(input_dtype).bits if not input_dtype.is_floating_point else None - out_bits = torch.iinfo(output_dtype).bits if not output_dtype.is_floating_point else None - - if scale: - if input_dtype.is_floating_point and not output_dtype.is_floating_point: - # float -> int with scaling: allow for rounding differences - torch.testing.assert_close(out, expected, atol=1, rtol=0) - elif input_dtype == torch.uint16 and output_dtype == torch.uint8: - # uint16 -> uint8 with scaling: allow large differences - torch.testing.assert_close(out, expected, atol=255, rtol=0) - else: - torch.testing.assert_close(out, expected) - else: - if in_bits is not None and out_bits is not None and out_bits > in_bits: - # uint to larger uint without scaling: allow large differences due to bit expansion - if input_dtype == torch.uint8 and output_dtype == torch.uint16: - torch.testing.assert_close(out, expected, atol=255, rtol=0) - else: - torch.testing.assert_close(out, expected, atol=1, rtol=0) - elif not input_dtype.is_floating_point and not output_dtype.is_floating_point: - # uint to uint without scaling (same or smaller bits): allow for rounding - if input_dtype == torch.uint16 and output_dtype == torch.uint8: - # uint16 -> uint8 can have large differences due to bit reduction - torch.testing.assert_close(out, expected, atol=255, rtol=0) - else: - torch.testing.assert_close(out, expected) - elif input_dtype.is_floating_point and not output_dtype.is_floating_point: - # float -> uint without scaling: allow for rounding differences - torch.testing.assert_close(out, expected, atol=1, rtol=0) - elif not input_dtype.is_floating_point and output_dtype.is_floating_point: - # uint -> float without scaling: allow for rounding differences - torch.testing.assert_close(out, expected, atol=1, rtol=0) - else: - torch.testing.assert_close(out, expected) + else: + expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale) + + atol = self._get_dtype_conversion_atol(input_dtype, output_dtype, scale) + torch.testing.assert_close(out, expected, rtol=0, atol=atol) def was_scaled(self, inpt): # this assumes the target dtype is float diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 4d6c10e6b66..bb5fac190a6 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -400,24 +400,17 @@ def _to_dtype_cvcuda( dtype_in = _cvcuda_to_torch_dtypes[inpt.dtype] cvc_dtype = _torch_to_cvcuda_dtypes[dtype] - if not scale: - return cvcuda.convertto(inpt, dtype=cvc_dtype) - scale_val, offset = 1.0, 0.0 - in_dtype_float = dtype_in.is_floating_point - out_dtype_float = dtype.is_floating_point - - # four cases for the scaling setup - # 1. float -> float - # 2. int -> int - # 3. float -> int - # 4. int -> float - if in_dtype_float == out_dtype_float: - scale_val, offset = 1.0, 0.0 - elif in_dtype_float and not out_dtype_float: - scale_val, offset = float(_max_value(dtype)), 0.0 - else: - scale_val, offset = 1.0 / float(_max_value(dtype_in)), 0.0 + if scale: + in_dtype_float = dtype_in.is_floating_point + out_dtype_float = dtype.is_floating_point + + if in_dtype_float == out_dtype_float: + scale_val, offset = 1.0, 0.0 + elif in_dtype_float and not out_dtype_float: + scale_val, offset = float(_max_value(dtype)), 0.0 + else: + scale_val, offset = 1.0 / float(_max_value(dtype_in)), 0.0 return cvcuda.convertto( inpt, From 736a2e65e9e5dacc5311811082b32ef6dfbaadad Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Wed, 26 Nov 2025 09:59:51 -0800 Subject: [PATCH 14/38] add int -> int scaling setup for cvcuda, use bit diff for scale --- test/test_transforms_v2.py | 2 -- torchvision/transforms/v2/functional/_misc.py | 11 +++++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f94dc7c6008..deaf5947b2b 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2748,8 +2748,6 @@ def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_ pytest.xfail("float to int64 conversion is not supported") if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda": pytest.xfail("uint8 to uint16 conversion is not supported on cuda") - if input_dtype == torch.uint8 and output_dtype == torch.uint16 and scale and make_input == make_image_cvcuda: - pytest.xfail("uint8 to uint16 conversion with scale is not supported in F._misc._to_dtype_cvcuda") input = make_input(dtype=input_dtype, device=device) out = F.to_dtype(input, dtype=output_dtype, scale=scale) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index bb5fac190a6..1cb1ffa564e 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -388,7 +388,9 @@ def _to_dtype_cvcuda( 4. int -> float If scale is True, the values will be scaled to the new dtype. If scale is False, the values will not be scaled. - The scale values for float -> float and int -> int are 1.0 and 0.0 respectively. + The scale values for float -> float are 1.0 and 0.0 respectively. + The scale values for int -> int are 2^(bit_diff) of the new dtype. + Where bit_diff is the difference in the number of bits of the new dtype and the input dtype. The scale values for float -> int and int -> float are the maximum value of the new dtype. Returns: @@ -405,8 +407,13 @@ def _to_dtype_cvcuda( in_dtype_float = dtype_in.is_floating_point out_dtype_float = dtype.is_floating_point - if in_dtype_float == out_dtype_float: + if in_dtype_float and out_dtype_float: scale_val, offset = 1.0, 0.0 + elif not in_dtype_float and not out_dtype_float: + in_bits = torch.iinfo(dtype_in).bits + out_bits = torch.iinfo(dtype).bits + scale_val = float(2 ** (out_bits - in_bits)) + offset = 0.0 elif in_dtype_float and not out_dtype_float: scale_val, offset = float(_max_value(dtype)), 0.0 else: From 7a231b1594c58a2001aac1a5ad9ac3ef4f0b129e Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Wed, 26 Nov 2025 10:04:42 -0800 Subject: [PATCH 15/38] further simplify todtype test --- test/test_transforms_v2.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index deaf5947b2b..aa920934bc3 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2753,12 +2753,10 @@ def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_ out = F.to_dtype(input, dtype=output_dtype, scale=scale) if make_input == make_image_cvcuda: - expected = self.reference_convert_dtype_image_tensor( - F.cvcuda_to_tensor(input), dtype=output_dtype, scale=scale - ) + input = F.cvcuda_to_tensor(input) out = F.cvcuda_to_tensor(out) - else: - expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale) + + expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale) atol = self._get_dtype_conversion_atol(input_dtype, output_dtype, scale) torch.testing.assert_close(out, expected, rtol=0, atol=atol) From d3e45733887f10e8997c177fc88d24c41f831e61 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 17:50:41 -0800 Subject: [PATCH 16/38] update todtype based on PR reviews --- torchvision/transforms/v2/_misc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 305149c87b1..0676ccb5fdb 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -17,6 +17,7 @@ get_bounding_boxes, get_keypoints, has_any, + is_cvcuda_tensor, is_pure_tensor, ) @@ -267,7 +268,7 @@ class ToDtype(Transform): Default: ``False``. """ - _transformed_types = (torch.Tensor,) + _transformed_types = (torch.Tensor, is_cvcuda_tensor) def __init__( self, dtype: Union[torch.dtype, dict[Union[type, str], Optional[torch.dtype]]], scale: bool = False From ec93ba33746e1908ca2cd259d7440b4be059623e Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 17:53:48 -0800 Subject: [PATCH 17/38] cleanup commnet, variable names --- test/test_transforms_v2.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index aa920934bc3..8220cfcb953 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2670,10 +2670,10 @@ def test_functional(self, make_input, input_dtype, output_dtype, device, scale): @pytest.mark.parametrize("scale", (True, False)) @pytest.mark.parametrize("as_dict", (True, False)) def test_transform(self, make_input, input_dtype, output_dtype, device, scale, as_dict): - input = make_input(dtype=input_dtype, device=device) + inpt = make_input(dtype=input_dtype, device=device) if as_dict: - output_dtype = {type(input): output_dtype} - check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input, check_sample_input=not as_dict) + output_dtype = {type(inpt): output_dtype} + check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), inpt, check_sample_input=not as_dict) def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False): input_dtype = image.dtype @@ -2749,14 +2749,14 @@ def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_ if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda": pytest.xfail("uint8 to uint16 conversion is not supported on cuda") - input = make_input(dtype=input_dtype, device=device) - out = F.to_dtype(input, dtype=output_dtype, scale=scale) + inpt = make_input(dtype=input_dtype, device=device) + out = F.to_dtype(inpt, dtype=output_dtype, scale=scale) if make_input == make_image_cvcuda: - input = F.cvcuda_to_tensor(input) + inpt = F.cvcuda_to_tensor(inpt) out = F.cvcuda_to_tensor(out) - expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale) + expected = self.reference_convert_dtype_image_tensor(inpt, dtype=output_dtype, scale=scale) atol = self._get_dtype_conversion_atol(input_dtype, output_dtype, scale) torch.testing.assert_close(out, expected, rtol=0, atol=atol) From 89122dbcbbad8bd7ef88724e17332b68ca5b96fe Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 10:47:01 -0800 Subject: [PATCH 18/38] update to_dtype_cvcuda name --- torchvision/transforms/v2/functional/_misc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 1cb1ffa564e..4562fb6fb3c 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -369,7 +369,7 @@ def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: boo _cvcuda_to_torch_dtypes = {v: k for k, v in _torch_to_cvcuda_dtypes.items()} -def _to_dtype_cvcuda( +def _to_dtype_image_cvcuda( inpt: "cvcuda.Tensor", dtype: torch.dtype, scale: bool = False, @@ -428,7 +428,7 @@ def _to_dtype_cvcuda( if CVCUDA_AVAILABLE: - _register_kernel_internal(to_dtype, _import_cvcuda().Tensor)(_to_dtype_cvcuda) + _register_kernel_internal(to_dtype, _import_cvcuda().Tensor)(_to_dtype_image_cvcuda) def sanitize_bounding_boxes( From 1b0d29569b79eb862f821c42344b233d6cf05c91 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 10:52:05 -0800 Subject: [PATCH 19/38] update to standards from flip PR --- test/test_transforms_v2.py | 1 - torchvision/transforms/v2/_misc.py | 8 ++++++-- torchvision/transforms/v2/_utils.py | 8 ++++---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 8220cfcb953..e826d739f62 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -25,7 +25,6 @@ assert_equal, cache, cpu_and_cuda, - cvcuda_to_pil_compatible_tensor, freeze_rng_state, ignore_jit_no_profile_information_warning, make_bounding_boxes, diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 0676ccb5fdb..de24eb159a0 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -9,6 +9,7 @@ from torchvision import transforms as _transforms, tv_tensors from torchvision.transforms.v2 import functional as F, Transform +from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor from ._utils import ( _parse_labels_getter, @@ -17,11 +18,13 @@ get_bounding_boxes, get_keypoints, has_any, - is_cvcuda_tensor, is_pure_tensor, ) +CVCUDA_AVAILABLE = _is_cvcuda_available() + + # TODO: do we want/need to expose this? class Identity(Transform): def transform(self, inpt: Any, params: dict[str, Any]) -> Any: @@ -268,7 +271,8 @@ class ToDtype(Transform): Default: ``False``. """ - _transformed_types = (torch.Tensor, is_cvcuda_tensor) + if CVCUDA_AVAILABLE: + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) def __init__( self, dtype: Union[torch.dtype, dict[Union[type, str], Optional[torch.dtype]]], scale: bool = False diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 3fc33ce5964..e803aa49c60 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -15,8 +15,8 @@ from torchvision._utils import sequence_to_str from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor -from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor +from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]: @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") @@ -207,7 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, - is_cvcuda_tensor, + _is_cvcuda_tensor, ), ) } From 009f925e64577d39d54cd611174ee0800e6ad61d Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 10:53:16 -0800 Subject: [PATCH 20/38] remove cvcuda updates to augment --- torchvision/transforms/v2/functional/_augment.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 7ce5bdc7b7e..a904d8d7cbd 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,5 +1,4 @@ import io -from typing import TYPE_CHECKING import PIL.Image @@ -9,15 +8,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal - - -CVCUDA_AVAILABLE = _is_cvcuda_available() - -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 +from ._utils import _get_kernel, _register_kernel_internal def erase( From 41af724d717acad57dbcc61bbe068f5d9ea61368 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 10:53:55 -0800 Subject: [PATCH 21/38] remove cvcuda refs from color --- torchvision/transforms/v2/functional/_color.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 5be9c62902a..be254c0d63a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,5 +1,3 @@ -from typing import TYPE_CHECKING - import PIL.Image import torch from torch.nn.functional import conv2d @@ -11,15 +9,7 @@ from ._misc import _num_value_bits, to_dtype_image from ._type_conversion import pil_to_tensor, to_pil_image -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal - - -CVCUDA_AVAILABLE = _is_cvcuda_available() - -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 +from ._utils import _get_kernel, _register_kernel_internal def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: From d12e4df1a85a8820c9dd09bf8b68b211dde7bdbd Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 13:18:47 -0800 Subject: [PATCH 22/38] refactor dtype converters to be in utils --- torchvision/transforms/v2/functional/_misc.py | 36 +++++----------- .../transforms/v2/functional/_utils.py | 41 +++++++++++++++++++ 2 files changed, 52 insertions(+), 25 deletions(-) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 4562fb6fb3c..b58df257f30 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -13,7 +13,15 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor +from ._utils import ( + _get_cvcuda_type_from_torch_dtype, + _get_kernel, + _get_torch_dtype_from_cvcuda_type, + _import_cvcuda, + _is_cvcuda_available, + _register_kernel_internal, + is_pure_tensor, +) CVCUDA_AVAILABLE = _is_cvcuda_available() @@ -347,28 +355,6 @@ def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: boo return inpt.to(dtype) -# cvcuda is only used if it is installed, so we can simply define empty mappings -_torch_to_cvcuda_dtypes = {} -_cvcuda_to_torch_dtypes = {} -if CVCUDA_AVAILABLE: - # put the entire conversion set here - # only a subset are used for torchvision - _torch_to_cvcuda_dtypes = { - torch.uint8: cvcuda.Type.U8, - torch.uint16: cvcuda.Type.U16, - torch.uint32: cvcuda.Type.U32, - torch.uint64: cvcuda.Type.U64, - torch.int8: cvcuda.Type.S8, - torch.int16: cvcuda.Type.S16, - torch.int32: cvcuda.Type.S32, - torch.int64: cvcuda.Type.S64, - torch.float32: cvcuda.Type.F32, - torch.float64: cvcuda.Type.F64, - } - # create reverse mapping - _cvcuda_to_torch_dtypes = {v: k for k, v in _torch_to_cvcuda_dtypes.items()} - - def _to_dtype_image_cvcuda( inpt: "cvcuda.Tensor", dtype: torch.dtype, @@ -399,8 +385,8 @@ def _to_dtype_image_cvcuda( """ cvcuda = _import_cvcuda() - dtype_in = _cvcuda_to_torch_dtypes[inpt.dtype] - cvc_dtype = _torch_to_cvcuda_dtypes[dtype] + dtype_in = _get_torch_dtype_from_cvcuda_type(inpt.dtype) + cvc_dtype = _get_cvcuda_type_from_torch_dtype(dtype) scale_val, offset = 1.0, 0.0 if scale: diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 11480b30ef9..f99f65bc701 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -177,3 +177,44 @@ def _is_cvcuda_tensor(inpt: Any) -> bool: return isinstance(inpt, cvcuda.Tensor) except ImportError: return False + + +# cvcuda is only used if it is installed, so we can simply define empty mappings +_torch_to_cvcuda_dtypes: dict[torch.dtype, "cvcuda.Type"] = {} +_cvcuda_to_torch_dtypes: dict["cvcuda.Type", torch.dtype] = {} + + +def _populate_cvcuda_dtype_tables(): + cvcuda = _import_cvcuda() + + global _torch_to_cvcuda_dtypes + global _cvcuda_to_torch_dtypes + + # put the entire conversion set here + # only a subset are used for torchvision + _torch_to_cvcuda_dtypes = { + torch.uint8: cvcuda.Type.U8, + torch.uint16: cvcuda.Type.U16, + torch.uint32: cvcuda.Type.U32, + torch.uint64: cvcuda.Type.U64, + torch.int8: cvcuda.Type.S8, + torch.int16: cvcuda.Type.S16, + torch.int32: cvcuda.Type.S32, + torch.int64: cvcuda.Type.S64, + torch.float32: cvcuda.Type.F32, + torch.float64: cvcuda.Type.F64, + } + # create reverse mapping + _cvcuda_to_torch_dtypes = {v: k for k, v in _torch_to_cvcuda_dtypes.items()} + + +def _get_cvcuda_type_from_torch_dtype(dtype: torch.dtype) -> "cvcuda.Type": + if len(_torch_to_cvcuda_dtypes.keys()) == 0: + _populate_cvcuda_dtype_tables() + return _torch_to_cvcuda_dtypes[dtype] + + +def _get_torch_dtype_from_cvcuda_type(dtype: "cvcuda.Type") -> torch.dtype: + if len(_cvcuda_to_torch_dtypes.keys()) == 0: + _populate_cvcuda_dtype_tables() + return _cvcuda_to_torch_dtypes[dtype] From c198cf083e56acb4b46b317176cdfc2f56c83c08 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 13:21:29 -0800 Subject: [PATCH 23/38] add type checking for cvcuda --- torchvision/transforms/v2/functional/_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index f99f65bc701..29fc7e516e9 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -1,10 +1,13 @@ import functools from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch from torchvision import tv_tensors +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] + _FillType = Union[int, float, Sequence[int], Sequence[float], None] _FillTypeJIT = Optional[list[float]] From 18df67fb142322983cd5167fd79d216bc082788b Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 14:05:16 -0800 Subject: [PATCH 24/38] provide better error for todtype --- torchvision/transforms/v2/functional/_utils.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 29fc7e516e9..e0b8800a956 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -212,12 +212,18 @@ def _populate_cvcuda_dtype_tables(): def _get_cvcuda_type_from_torch_dtype(dtype: torch.dtype) -> "cvcuda.Type": - if len(_torch_to_cvcuda_dtypes.keys()) == 0: + if len(_torch_to_cvcuda_dtypes) == 0: _populate_cvcuda_dtype_tables() - return _torch_to_cvcuda_dtypes[dtype] + cv_type = _torch_to_cvcuda_dtypes.get(dtype) + if cv_type is None: + raise ValueError(f"No CV-CUDA type found for torch dtype {dtype}") + return cv_type def _get_torch_dtype_from_cvcuda_type(dtype: "cvcuda.Type") -> torch.dtype: - if len(_cvcuda_to_torch_dtypes.keys()) == 0: + if len(_cvcuda_to_torch_dtypes) == 0: _populate_cvcuda_dtype_tables() - return _cvcuda_to_torch_dtypes[dtype] + torch_dtype = _cvcuda_to_torch_dtypes.get(dtype) + if torch_dtype is None: + raise ValueError(f"No torch dtype found for CV-CUDA type {dtype}") + return torch_dtype From c5a2a5a7ece6d25a5f9becc4cf4d13ddde984ee4 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 18:10:51 -0800 Subject: [PATCH 25/38] refactor to simplify setup for dtype conversions --- .../transforms/v2/functional/_utils.py | 48 +++++++++---------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index e0b8800a956..eff559ce17e 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -187,33 +187,19 @@ def _is_cvcuda_tensor(inpt: Any) -> bool: _cvcuda_to_torch_dtypes: dict["cvcuda.Type", torch.dtype] = {} -def _populate_cvcuda_dtype_tables(): - cvcuda = _import_cvcuda() - - global _torch_to_cvcuda_dtypes - global _cvcuda_to_torch_dtypes - - # put the entire conversion set here - # only a subset are used for torchvision - _torch_to_cvcuda_dtypes = { - torch.uint8: cvcuda.Type.U8, - torch.uint16: cvcuda.Type.U16, - torch.uint32: cvcuda.Type.U32, - torch.uint64: cvcuda.Type.U64, - torch.int8: cvcuda.Type.S8, - torch.int16: cvcuda.Type.S16, - torch.int32: cvcuda.Type.S32, - torch.int64: cvcuda.Type.S64, - torch.float32: cvcuda.Type.F32, - torch.float64: cvcuda.Type.F64, - } - # create reverse mapping - _cvcuda_to_torch_dtypes = {v: k for k, v in _torch_to_cvcuda_dtypes.items()} - - def _get_cvcuda_type_from_torch_dtype(dtype: torch.dtype) -> "cvcuda.Type": if len(_torch_to_cvcuda_dtypes) == 0: - _populate_cvcuda_dtype_tables() + cvcuda = _import_cvcuda() + _torch_to_cvcuda_dtypes[torch.uint8] = cvcuda.Type.U8 + _torch_to_cvcuda_dtypes[torch.uint16] = cvcuda.Type.U16 + _torch_to_cvcuda_dtypes[torch.uint32] = cvcuda.Type.U32 + _torch_to_cvcuda_dtypes[torch.uint64] = cvcuda.Type.U64 + _torch_to_cvcuda_dtypes[torch.int8] = cvcuda.Type.S8 + _torch_to_cvcuda_dtypes[torch.int16] = cvcuda.Type.S16 + _torch_to_cvcuda_dtypes[torch.int32] = cvcuda.Type.S32 + _torch_to_cvcuda_dtypes[torch.int64] = cvcuda.Type.S64 + _torch_to_cvcuda_dtypes[torch.float32] = cvcuda.Type.F32 + _torch_to_cvcuda_dtypes[torch.float64] = cvcuda.Type.F64 cv_type = _torch_to_cvcuda_dtypes.get(dtype) if cv_type is None: raise ValueError(f"No CV-CUDA type found for torch dtype {dtype}") @@ -222,7 +208,17 @@ def _get_cvcuda_type_from_torch_dtype(dtype: torch.dtype) -> "cvcuda.Type": def _get_torch_dtype_from_cvcuda_type(dtype: "cvcuda.Type") -> torch.dtype: if len(_cvcuda_to_torch_dtypes) == 0: - _populate_cvcuda_dtype_tables() + cvcuda = _import_cvcuda() + _cvcuda_to_torch_dtypes[cvcuda.Type.U8] = torch.uint8 + _cvcuda_to_torch_dtypes[cvcuda.Type.U16] = torch.uint16 + _cvcuda_to_torch_dtypes[cvcuda.Type.U32] = torch.uint32 + _cvcuda_to_torch_dtypes[cvcuda.Type.U64] = torch.uint64 + _cvcuda_to_torch_dtypes[cvcuda.Type.S8] = torch.int8 + _cvcuda_to_torch_dtypes[cvcuda.Type.S16] = torch.int16 + _cvcuda_to_torch_dtypes[cvcuda.Type.S32] = torch.int32 + _cvcuda_to_torch_dtypes[cvcuda.Type.S64] = torch.int64 + _cvcuda_to_torch_dtypes[cvcuda.Type.F32] = torch.float32 + _cvcuda_to_torch_dtypes[cvcuda.Type.F64] = torch.float64 torch_dtype = _cvcuda_to_torch_dtypes.get(dtype) if torch_dtype is None: raise ValueError(f"No torch dtype found for CV-CUDA type {dtype}") From 7f41c9572724166c437d955fa990ba1cc01138ef Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Fri, 5 Dec 2025 15:13:10 -0800 Subject: [PATCH 26/38] fix: not testing transform class correctness in ToDtype, resolved --- test/test_transforms_v2.py | 5 +++-- torchvision/transforms/v2/_misc.py | 12 ++++++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index e826d739f62..71910daee9b 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2742,14 +2742,15 @@ def _get_dtype_conversion_atol(self, input_dtype, output_dtype, scale): ), ], ) - def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_input): + @pytest.mark.parametrize("fn", [F.to_dtype, transform_cls_to_functional(transforms.ToDtype)]) + def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_input, fn): if input_dtype.is_floating_point and output_dtype == torch.int64: pytest.xfail("float to int64 conversion is not supported") if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda": pytest.xfail("uint8 to uint16 conversion is not supported on cuda") inpt = make_input(dtype=input_dtype, device=device) - out = F.to_dtype(inpt, dtype=output_dtype, scale=scale) + out = fn(inpt, dtype=output_dtype, scale=scale) if make_input == make_image_cvcuda: inpt = F.cvcuda_to_tensor(inpt) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index de24eb159a0..93fdbca02ac 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -299,7 +299,11 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any: if isinstance(self.dtype, torch.dtype): # For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype # is a simple torch.dtype - if not is_pure_tensor(inpt) and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)): + if ( + not is_pure_tensor(inpt) + and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) + and (CVCUDA_AVAILABLE and not _is_cvcuda_tensor(inpt)) + ): return inpt dtype: Optional[torch.dtype] = self.dtype @@ -316,7 +320,11 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any: 'e.g. dtype={tv_tensors.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.' ) - supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) + supports_scaling = ( + is_pure_tensor(inpt) + or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) + or (CVCUDA_AVAILABLE and _is_cvcuda_tensor(inpt)) + ) if dtype is None: if self.scale and supports_scaling: warnings.warn( From 9b41552cf3e44a56079dc1e4dd93bad42138c946 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 8 Dec 2025 13:53:41 -0800 Subject: [PATCH 27/38] preserve previous torchvision test behavior for non cvcuda inputs --- test/test_transforms_v2.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 71910daee9b..af0a70ec24d 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2758,8 +2758,14 @@ def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_ expected = self.reference_convert_dtype_image_tensor(inpt, dtype=output_dtype, scale=scale) - atol = self._get_dtype_conversion_atol(input_dtype, output_dtype, scale) - torch.testing.assert_close(out, expected, rtol=0, atol=atol) + if make_input is make_image_cvcuda: + atol = self._get_dtype_conversion_atol(input_dtype, output_dtype, scale) + assert_close(out, expected, rtol=0, atol=atol) + else: + if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale: + torch.testing.assert_close(out, expected, atol=1, rtol=0) + else: + torch.testing.assert_close(out, expected) def was_scaled(self, inpt): # this assumes the target dtype is float From b9c378bbff0fde57b7aad9be7f6cc811729acefd Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 8 Dec 2025 13:59:38 -0800 Subject: [PATCH 28/38] further simplify branching flow of testtodtype image correctness --- test/test_transforms_v2.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index af0a70ec24d..91e2f1b49ac 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2707,7 +2707,7 @@ def fn(value): return torch.tensor(tree_map(fn, image.tolist())).to(dtype=output_dtype, device=image.device) - def _get_dtype_conversion_atol(self, input_dtype, output_dtype, scale): + def _get_dtype_conversion_atol_cvcuda(self, input_dtype, output_dtype, scale): is_uint16_to_uint8 = input_dtype == torch.uint16 and output_dtype == torch.uint8 is_uint8_to_uint16 = input_dtype == torch.uint8 and output_dtype == torch.uint16 changes_type_class = output_dtype.is_floating_point != input_dtype.is_floating_point @@ -2758,14 +2758,14 @@ def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_ expected = self.reference_convert_dtype_image_tensor(inpt, dtype=output_dtype, scale=scale) + atol, rtol = None, None if make_input is make_image_cvcuda: - atol = self._get_dtype_conversion_atol(input_dtype, output_dtype, scale) - assert_close(out, expected, rtol=0, atol=atol) - else: - if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale: - torch.testing.assert_close(out, expected, atol=1, rtol=0) - else: - torch.testing.assert_close(out, expected) + atol = self._get_dtype_conversion_atol_cvcuda(input_dtype, output_dtype, scale) + rtol = 0 + elif input_dtype.is_floating_point and not output_dtype.is_floating_point and scale: + atol, rtol = 1, 0 + + torch.testing.assert_close(out, expected, atol=atol, rtol=rtol) def was_scaled(self, inpt): # this assumes the target dtype is float From 1781244d082497a46b5349c277445102f8252421 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 8 Dec 2025 14:09:35 -0800 Subject: [PATCH 29/38] add functional signature tests, fix bug in type check in todtype transform class --- test/test_transforms_v2.py | 19 ++++++++++++++++++- torchvision/transforms/v2/_misc.py | 2 +- torchvision/transforms/v2/functional/_misc.py | 2 +- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 91e2f1b49ac..4258768def5 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2627,6 +2627,23 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, sca scale=scale, ) + @pytest.mark.parametrize( + ("kernel", "input_type"), + [ + (F.to_dtype_image, torch.Tensor), + (F.to_dtype_video, tv_tensors.Video), + pytest.param( + F._misc._to_dtype_image_cvcuda, + None, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), + ), + ], + ) + def test_functional_signature(self, kernel, input_type): + if kernel is F._misc._to_dtype_image_cvcuda: + input_type = _import_cvcuda().Tensor + check_functional_kernel_signature_match(F.to_dtype, kernel=kernel, input_type=input_type) + @pytest.mark.parametrize( "make_input", [ @@ -2752,7 +2769,7 @@ def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_ inpt = make_input(dtype=input_dtype, device=device) out = fn(inpt, dtype=output_dtype, scale=scale) - if make_input == make_image_cvcuda: + if make_input is make_image_cvcuda: inpt = F.cvcuda_to_tensor(inpt) out = F.cvcuda_to_tensor(out) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 93fdbca02ac..fc310ee84a2 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -302,7 +302,7 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any: if ( not is_pure_tensor(inpt) and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) - and (CVCUDA_AVAILABLE and not _is_cvcuda_tensor(inpt)) + and not (CVCUDA_AVAILABLE and _is_cvcuda_tensor(inpt)) ): return inpt diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index b58df257f30..ce8df8983f1 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -357,7 +357,7 @@ def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: boo def _to_dtype_image_cvcuda( inpt: "cvcuda.Tensor", - dtype: torch.dtype, + dtype: torch.dtype = torch.float, scale: bool = False, ) -> "cvcuda.Tensor": """ From 626b47a2d2709860b9b855cf5f984f02394fdbfb Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 9 Dec 2025 08:50:27 -0800 Subject: [PATCH 30/38] add consolidated cvcuda test markers --- test/test_transforms_v2.py | 90 ++++++++++++-------------------------- 1 file changed, 29 insertions(+), 61 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 4258768def5..c67a8ff1500 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -60,9 +60,17 @@ ) -CVCUDA_AVAILABLE = _is_cvcuda_available() -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() +CV_CUDA_TEST = ( + pytest.mark.skipif(not _is_cvcuda_available(), reason="CVCUDA is not available"), + pytest.mark.needs_cuda, +) + + +def CV_CUDA_TEST_CLASS(cls): + for mark in CV_CUDA_TEST: + cls = mark(cls) + return cls + # turns all warnings into errors for this module pytestmark = [pytest.mark.filterwarnings("error")] @@ -1240,10 +1248,7 @@ def test_kernel_video(self): make_image_tensor, make_image_pil, make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1259,11 +1264,7 @@ def test_functional(self, make_input): (F.horizontal_flip_image, torch.Tensor), (F._geometry._horizontal_flip_image_pil, PIL.Image.Image), (F.horizontal_flip_image, tv_tensors.Image), - pytest.param( - F._geometry._horizontal_flip_image_cvcuda, - None, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(F._geometry._horizontal_flip_image_cvcuda, None, marks=CV_CUDA_TEST), (F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes), (F.horizontal_flip_mask, tv_tensors.Mask), (F.horizontal_flip_video, tv_tensors.Video), @@ -1281,10 +1282,7 @@ def test_functional_signature(self, kernel, input_type): make_image_tensor, make_image_pil, make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1302,10 +1300,7 @@ def test_transform(self, make_input, device): "make_input", [ make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), ], ) def test_image_correctness(self, fn, make_input): @@ -1370,10 +1365,7 @@ def test_keypoints_correctness(self, fn): make_image_tensor, make_image_pil, make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1882,10 +1874,7 @@ def test_kernel_video(self): make_image_tensor, make_image_pil, make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1901,11 +1890,7 @@ def test_functional(self, make_input): (F.vertical_flip_image, torch.Tensor), (F._geometry._vertical_flip_image_pil, PIL.Image.Image), (F.vertical_flip_image, tv_tensors.Image), - pytest.param( - F._geometry._vertical_flip_image_cvcuda, - None, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(F._geometry._vertical_flip_image_cvcuda, None, marks=CV_CUDA_TEST), (F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes), (F.vertical_flip_mask, tv_tensors.Mask), (F.vertical_flip_video, tv_tensors.Video), @@ -1923,10 +1908,7 @@ def test_functional_signature(self, kernel, input_type): make_image_tensor, make_image_pil, make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1942,10 +1924,7 @@ def test_transform(self, make_input, device): "make_input", [ make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), ], ) def test_image_correctness(self, fn, make_input): @@ -2006,10 +1985,7 @@ def test_keypoints_correctness(self, fn): make_image_tensor, make_image_pil, make_image, - pytest.param( - make_image_cvcuda, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), make_bounding_boxes, make_segmentation_mask, make_video, @@ -2635,7 +2611,7 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, sca pytest.param( F._misc._to_dtype_image_cvcuda, None, - marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), + marks=CV_CUDA_TEST, ), ], ) @@ -2650,9 +2626,7 @@ def test_functional_signature(self, kernel, input_type): make_image_tensor, make_image, make_video, - pytest.param( - make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), ], ) @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) @@ -2675,9 +2649,7 @@ def test_functional(self, make_input, input_dtype, output_dtype, device, scale): make_bounding_boxes, make_segmentation_mask, make_video, - pytest.param( - make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), ], ) @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8]) @@ -2754,9 +2726,7 @@ def _get_dtype_conversion_atol_cvcuda(self, input_dtype, output_dtype, scale): "make_input", [ make_image, - pytest.param( - make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") - ), + pytest.param(make_image_cvcuda, marks=CV_CUDA_TEST), ], ) @pytest.mark.parametrize("fn", [F.to_dtype, transform_cls_to_functional(transforms.ToDtype)]) @@ -6870,8 +6840,7 @@ def test_functional_error(self): F.pil_to_tensor(object()) -@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") -@needs_cuda +@CV_CUDA_TEST_CLASS class TestToCVCUDATensor: @pytest.mark.parametrize("image_type", (torch.Tensor, tv_tensors.Image)) @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64]) @@ -6889,7 +6858,7 @@ def test_functional_and_transform(self, image_type, dtype, device, color_space, assert is_pure_tensor(image) output = fn(image) - assert isinstance(output, cvcuda.Tensor) + assert isinstance(output, _import_cvcuda().Tensor) assert F.get_size(output) == F.get_size(image) assert output is not None @@ -6932,9 +6901,8 @@ def test_round_trip(self, dtype, device, color_space, batch_size): assert result_tensor.shape[0] == batch_size -@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") -@needs_cuda -class TestCVDUDAToTensor: +@CV_CUDA_TEST_CLASS +class TestCVCUDAToTensor: @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64]) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) From e8540ba56d966a73426892ed58412d86c6aa7871 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 9 Dec 2025 09:18:27 -0800 Subject: [PATCH 31/38] finalize consolidated cvcuda skip behavior --- test/test_transforms_v2.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index c67a8ff1500..8a8fd249145 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -60,16 +60,10 @@ ) -CV_CUDA_TEST = ( +CV_CUDA_TEST = [ pytest.mark.skipif(not _is_cvcuda_available(), reason="CVCUDA is not available"), pytest.mark.needs_cuda, -) - - -def CV_CUDA_TEST_CLASS(cls): - for mark in CV_CUDA_TEST: - cls = mark(cls) - return cls +] # turns all warnings into errors for this module @@ -6840,8 +6834,9 @@ def test_functional_error(self): F.pil_to_tensor(object()) -@CV_CUDA_TEST_CLASS class TestToCVCUDATensor: + pytestmark = CV_CUDA_TEST + @pytest.mark.parametrize("image_type", (torch.Tensor, tv_tensors.Image)) @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64]) @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -6901,8 +6896,9 @@ def test_round_trip(self, dtype, device, color_space, batch_size): assert result_tensor.shape[0] == batch_size -@CV_CUDA_TEST_CLASS class TestCVCUDAToTensor: + pytestmark = CV_CUDA_TEST + @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64]) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) From 5aa4b3d105d1ddd7e2a3afda7bd0f57c1e5eb9b3 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 9 Dec 2025 10:53:40 -0800 Subject: [PATCH 32/38] revert var name change back to input --- test/test_transforms_v2.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 8a8fd249145..f6548e89336 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2652,10 +2652,10 @@ def test_functional(self, make_input, input_dtype, output_dtype, device, scale): @pytest.mark.parametrize("scale", (True, False)) @pytest.mark.parametrize("as_dict", (True, False)) def test_transform(self, make_input, input_dtype, output_dtype, device, scale, as_dict): - inpt = make_input(dtype=input_dtype, device=device) + input = make_input(dtype=input_dtype, device=device) if as_dict: - output_dtype = {type(inpt): output_dtype} - check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), inpt, check_sample_input=not as_dict) + output_dtype = {type(input): output_dtype} + check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input, check_sample_input=not as_dict) def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False): input_dtype = image.dtype @@ -2730,14 +2730,14 @@ def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_ if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda": pytest.xfail("uint8 to uint16 conversion is not supported on cuda") - inpt = make_input(dtype=input_dtype, device=device) - out = fn(inpt, dtype=output_dtype, scale=scale) + input = make_input(dtype=input_dtype, device=device) + out = fn(input, dtype=output_dtype, scale=scale) if make_input is make_image_cvcuda: - inpt = F.cvcuda_to_tensor(inpt) + input = F.cvcuda_to_tensor(input) out = F.cvcuda_to_tensor(out) - expected = self.reference_convert_dtype_image_tensor(inpt, dtype=output_dtype, scale=scale) + expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale) atol, rtol = None, None if make_input is make_image_cvcuda: From 7cbf30e84553e0a48e12e97fcb8b2db1cff28061 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Fri, 12 Dec 2025 11:47:00 -0800 Subject: [PATCH 33/38] drop the dimensions and num channels variants for cvcuda --- torchvision/transforms/v2/functional/_meta.py | 22 +------------------ 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index af03ad018d4..6b8f19f12f4 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,16 +51,6 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) -def get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: - # CV-CUDA tensor is always in NHWC layout - # get_dimensions is CHW - return [image.shape[3], image.shape[1], image.shape[2]] - - -if CVCUDA_AVAILABLE: - _register_kernel_internal(get_dimensions, cvcuda.Tensor)(get_dimensions_image_cvcuda) - - def get_num_channels(inpt: torch.Tensor) -> int: if torch.jit.is_scripting(): return get_num_channels_image(inpt) @@ -97,16 +87,6 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels -def get_num_channels_image_cvcuda(image: "cvcuda.Tensor") -> int: - # CV-CUDA tensor is always in NHWC layout - # get_num_channels is C - return image.shape[3] - - -if CVCUDA_AVAILABLE: - _register_kernel_internal(get_num_channels, cvcuda.Tensor)(get_num_channels_image_cvcuda) - - def get_size(inpt: torch.Tensor) -> list[int]: if torch.jit.is_scripting(): return get_size_image(inpt) @@ -145,7 +125,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: if CVCUDA_AVAILABLE: - _register_kernel_internal(get_size, _import_cvcuda().Tensor)(get_size_image_cvcuda) + _get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda) @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False) From 93bf6747c3ee94654b819185157d497fd13d45a5 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Fri, 12 Dec 2025 11:51:09 -0800 Subject: [PATCH 34/38] drop _is_cvcuda_tensor from _utils query_size query_chw unused in this pr --- torchvision/transforms/v2/_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index e803aa49c60..bb6051b4e61 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -16,7 +16,7 @@ from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor -from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor +from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]: @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) } if not chws: raise TypeError("No image or video was found in the sample") @@ -207,7 +207,6 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, - _is_cvcuda_tensor, ), ) } From ecf3c584b6f7de0234a59f121c4c3e2aebaaee80 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Fri, 12 Dec 2025 11:53:23 -0800 Subject: [PATCH 35/38] simplify ToDtype class --- torchvision/transforms/v2/_misc.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index fc310ee84a2..26749c855a4 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -9,7 +9,7 @@ from torchvision import transforms as _transforms, tv_tensors from torchvision.transforms.v2 import functional as F, Transform -from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor +from torchvision.transforms.v2.functional._utils import _is_cvcuda_tensor from ._utils import ( _parse_labels_getter, @@ -22,9 +22,6 @@ ) -CVCUDA_AVAILABLE = _is_cvcuda_available() - - # TODO: do we want/need to expose this? class Identity(Transform): def transform(self, inpt: Any, params: dict[str, Any]) -> Any: @@ -271,8 +268,7 @@ class ToDtype(Transform): Default: ``False``. """ - if CVCUDA_AVAILABLE: - _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) def __init__( self, dtype: Union[torch.dtype, dict[Union[type, str], Optional[torch.dtype]]], scale: bool = False @@ -302,7 +298,7 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any: if ( not is_pure_tensor(inpt) and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) - and not (CVCUDA_AVAILABLE and _is_cvcuda_tensor(inpt)) + and not _is_cvcuda_tensor(inpt) ): return inpt @@ -321,9 +317,7 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any: ) supports_scaling = ( - is_pure_tensor(inpt) - or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) - or (CVCUDA_AVAILABLE and _is_cvcuda_tensor(inpt)) + is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or _is_cvcuda_tensor(inpt) ) if dtype is None: if self.scale and supports_scaling: From 8cd76dc76b3e426d46abd96a4053b435f783e2a2 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Fri, 12 Dec 2025 12:02:20 -0800 Subject: [PATCH 36/38] refactor to move the dtype tables to _misc adjacent with to_dtype_image_cvcuda --- torchvision/transforms/v2/functional/_misc.py | 37 +++++++++----- .../transforms/v2/functional/_utils.py | 48 +------------------ 2 files changed, 27 insertions(+), 58 deletions(-) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index ce8df8983f1..c8b53ea8267 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -13,15 +13,7 @@ from ._meta import _convert_bounding_box_format -from ._utils import ( - _get_cvcuda_type_from_torch_dtype, - _get_kernel, - _get_torch_dtype_from_cvcuda_type, - _import_cvcuda, - _is_cvcuda_available, - _register_kernel_internal, - is_pure_tensor, -) +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor CVCUDA_AVAILABLE = _is_cvcuda_available() @@ -355,6 +347,11 @@ def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: boo return inpt.to(dtype) +# cvcuda is only used if it is installed, so we can simply define empty mappings +_torch_to_cvcuda_dtypes: dict[torch.dtype, "cvcuda.Type"] = {} +_cvcuda_to_torch_dtypes: dict["cvcuda.Type", torch.dtype] = {} + + def _to_dtype_image_cvcuda( inpt: "cvcuda.Tensor", dtype: torch.dtype = torch.float, @@ -385,8 +382,26 @@ def _to_dtype_image_cvcuda( """ cvcuda = _import_cvcuda() - dtype_in = _get_torch_dtype_from_cvcuda_type(inpt.dtype) - cvc_dtype = _get_cvcuda_type_from_torch_dtype(dtype) + if not _torch_to_cvcuda_dtypes: + _torch_to_cvcuda_dtypes[torch.uint8] = cvcuda.Type.U8 + _torch_to_cvcuda_dtypes[torch.uint16] = cvcuda.Type.U16 + _torch_to_cvcuda_dtypes[torch.uint32] = cvcuda.Type.U32 + _torch_to_cvcuda_dtypes[torch.uint64] = cvcuda.Type.U64 + _torch_to_cvcuda_dtypes[torch.int8] = cvcuda.Type.S8 + _torch_to_cvcuda_dtypes[torch.int16] = cvcuda.Type.S16 + _torch_to_cvcuda_dtypes[torch.int32] = cvcuda.Type.S32 + _torch_to_cvcuda_dtypes[torch.int64] = cvcuda.Type.S64 + _torch_to_cvcuda_dtypes[torch.float32] = cvcuda.Type.F32 + _torch_to_cvcuda_dtypes[torch.float64] = cvcuda.Type.F64 + + if not _cvcuda_to_torch_dtypes: + for k, v in _torch_to_cvcuda_dtypes.items(): + _cvcuda_to_torch_dtypes[v] = k + + dtype_in = _cvcuda_to_torch_dtypes.get(inpt.dtype) + cvc_dtype = _torch_to_cvcuda_dtypes.get(dtype) + if dtype_in is None or cvc_dtype is None: + raise ValueError(f"No torch or cvcuda dtype found for dtype {dtype} or {inpt.dtype}") scale_val, offset = 1.0, 0.0 if scale: diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index eff559ce17e..11480b30ef9 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -1,13 +1,10 @@ import functools from collections.abc import Sequence -from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Optional, Union import torch from torchvision import tv_tensors -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] - _FillType = Union[int, float, Sequence[int], Sequence[float], None] _FillTypeJIT = Optional[list[float]] @@ -180,46 +177,3 @@ def _is_cvcuda_tensor(inpt: Any) -> bool: return isinstance(inpt, cvcuda.Tensor) except ImportError: return False - - -# cvcuda is only used if it is installed, so we can simply define empty mappings -_torch_to_cvcuda_dtypes: dict[torch.dtype, "cvcuda.Type"] = {} -_cvcuda_to_torch_dtypes: dict["cvcuda.Type", torch.dtype] = {} - - -def _get_cvcuda_type_from_torch_dtype(dtype: torch.dtype) -> "cvcuda.Type": - if len(_torch_to_cvcuda_dtypes) == 0: - cvcuda = _import_cvcuda() - _torch_to_cvcuda_dtypes[torch.uint8] = cvcuda.Type.U8 - _torch_to_cvcuda_dtypes[torch.uint16] = cvcuda.Type.U16 - _torch_to_cvcuda_dtypes[torch.uint32] = cvcuda.Type.U32 - _torch_to_cvcuda_dtypes[torch.uint64] = cvcuda.Type.U64 - _torch_to_cvcuda_dtypes[torch.int8] = cvcuda.Type.S8 - _torch_to_cvcuda_dtypes[torch.int16] = cvcuda.Type.S16 - _torch_to_cvcuda_dtypes[torch.int32] = cvcuda.Type.S32 - _torch_to_cvcuda_dtypes[torch.int64] = cvcuda.Type.S64 - _torch_to_cvcuda_dtypes[torch.float32] = cvcuda.Type.F32 - _torch_to_cvcuda_dtypes[torch.float64] = cvcuda.Type.F64 - cv_type = _torch_to_cvcuda_dtypes.get(dtype) - if cv_type is None: - raise ValueError(f"No CV-CUDA type found for torch dtype {dtype}") - return cv_type - - -def _get_torch_dtype_from_cvcuda_type(dtype: "cvcuda.Type") -> torch.dtype: - if len(_cvcuda_to_torch_dtypes) == 0: - cvcuda = _import_cvcuda() - _cvcuda_to_torch_dtypes[cvcuda.Type.U8] = torch.uint8 - _cvcuda_to_torch_dtypes[cvcuda.Type.U16] = torch.uint16 - _cvcuda_to_torch_dtypes[cvcuda.Type.U32] = torch.uint32 - _cvcuda_to_torch_dtypes[cvcuda.Type.U64] = torch.uint64 - _cvcuda_to_torch_dtypes[cvcuda.Type.S8] = torch.int8 - _cvcuda_to_torch_dtypes[cvcuda.Type.S16] = torch.int16 - _cvcuda_to_torch_dtypes[cvcuda.Type.S32] = torch.int32 - _cvcuda_to_torch_dtypes[cvcuda.Type.S64] = torch.int64 - _cvcuda_to_torch_dtypes[cvcuda.Type.F32] = torch.float32 - _cvcuda_to_torch_dtypes[cvcuda.Type.F64] = torch.float64 - torch_dtype = _cvcuda_to_torch_dtypes.get(dtype) - if torch_dtype is None: - raise ValueError(f"No torch dtype found for CV-CUDA type {dtype}") - return torch_dtype From 713810f382f3e7bfe1bd71b9bdbda212601bcec6 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Fri, 12 Dec 2025 12:03:05 -0800 Subject: [PATCH 37/38] drop the evergreen cvcuda import at file level --- torchvision/transforms/v2/functional/_misc.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index c8b53ea8267..9d2f8ae7701 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -19,8 +19,6 @@ if TYPE_CHECKING: import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 def normalize( From a68b7b5c1ef46ea83a080bc5b93c2312526a0ffd Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Fri, 12 Dec 2025 13:10:49 -0800 Subject: [PATCH 38/38] make atol thresholds clearer and smaller, drop uint16 to uint8 for cvcuda --- test/test_transforms_v2.py | 32 +++++++++++-------- torchvision/transforms/v2/functional/_misc.py | 12 ++++++- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f6548e89336..2c52c544fe1 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2690,23 +2690,22 @@ def fn(value): return torch.tensor(tree_map(fn, image.tolist())).to(dtype=output_dtype, device=image.device) - def _get_dtype_conversion_atol_cvcuda(self, input_dtype, output_dtype, scale): - is_uint16_to_uint8 = input_dtype == torch.uint16 and output_dtype == torch.uint8 - is_uint8_to_uint16 = input_dtype == torch.uint8 and output_dtype == torch.uint16 - changes_type_class = output_dtype.is_floating_point != input_dtype.is_floating_point - + def _get_dtype_conversion_atol_cvcuda(self, input_dtype, output_dtype): in_bits = torch.iinfo(input_dtype).bits if not input_dtype.is_floating_point else None out_bits = torch.iinfo(output_dtype).bits if not output_dtype.is_floating_point else None - expands_bits = in_bits is not None and out_bits is not None and out_bits > in_bits + narrows_bits = in_bits is not None and out_bits is not None and out_bits < in_bits - if is_uint16_to_uint8: - atol = 255 - elif is_uint8_to_uint16 and not scale: - atol = 255 - elif expands_bits and not scale: + # int->int with narrowing bits, allow atol=1 for rounding diffs + if narrows_bits: atol = 1 - elif changes_type_class: + # float->int check for same diff, rounding error on float + elif input_dtype.is_floating_point and not output_dtype.is_floating_point: atol = 1 + # if generating a float value from an int, allow small rounding error + elif not input_dtype.is_floating_point and output_dtype.is_floating_point: + atol = 1e-7 + # all other cases, should be exact + # uint8 -> uint16 promotion would be here else: atol = 0 @@ -2729,6 +2728,13 @@ def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_ pytest.xfail("float to int64 conversion is not supported") if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda": pytest.xfail("uint8 to uint16 conversion is not supported on cuda") + if ( + input_dtype == torch.uint16 + and output_dtype == torch.uint8 + and not scale + and make_input is make_image_cvcuda + ): + pytest.xfail("uint16 to uint8 conversion without scale is not supported for CV-CUDA.") input = make_input(dtype=input_dtype, device=device) out = fn(input, dtype=output_dtype, scale=scale) @@ -2741,7 +2747,7 @@ def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_ atol, rtol = None, None if make_input is make_image_cvcuda: - atol = self._get_dtype_conversion_atol_cvcuda(input_dtype, output_dtype, scale) + atol = self._get_dtype_conversion_atol_cvcuda(input_dtype, output_dtype) rtol = 0 elif input_dtype.is_floating_point and not output_dtype.is_floating_point and scale: atol, rtol = 1, 0 diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 9d2f8ae7701..6ae5466621c 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -401,6 +401,13 @@ def _to_dtype_image_cvcuda( if dtype_in is None or cvc_dtype is None: raise ValueError(f"No torch or cvcuda dtype found for dtype {dtype} or {inpt.dtype}") + # torchvision will overflow the values of uint16 when converting down to uint8 without scale + # example: 300 -> 255 (cvcuda) vs 300 mod 256 = 44 (torchvision) + # since it is not equivalent, raise an error for unsupported behavior + # the workaround could be using torch for dtype conversion directly via zero-copy + if dtype_in == torch.uint16 and dtype == torch.uint8 and not scale: + raise ValueError("uint16 to uint8 conversion without scale is not supported for CV-CUDA.") + scale_val, offset = 1.0, 0.0 if scale: in_dtype_float = dtype_in.is_floating_point @@ -414,7 +421,10 @@ def _to_dtype_image_cvcuda( scale_val = float(2 ** (out_bits - in_bits)) offset = 0.0 elif in_dtype_float and not out_dtype_float: - scale_val, offset = float(_max_value(dtype)), 0.0 + # Mirror the scaling factor which torchvision uses + eps = 1e-3 + max_val = float(_max_value(dtype)) + scale_val, offset = max_val + 1.0 - eps, 0.0 else: scale_val, offset = 1.0 / float(_max_value(dtype_in)), 0.0