From 44db71c0772e5ef5758c38d0e4e8ad9995946c80 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:14:49 -0800 Subject: [PATCH 01/22] implement additional cvcuda infra for all branches to avoid duplicate setup --- torchvision/transforms/v2/_transform.py | 4 ++-- torchvision/transforms/v2/_utils.py | 3 ++- .../transforms/v2/functional/__init__.py | 2 +- .../transforms/v2/functional/_augment.py | 11 ++++++++++- .../transforms/v2/functional/_color.py | 12 +++++++++++- .../transforms/v2/functional/_geometry.py | 19 +++++++++++++++++-- torchvision/transforms/v2/functional/_misc.py | 11 +++++++++-- .../transforms/v2/functional/_utils.py | 16 ++++++++++++++++ 8 files changed, 68 insertions(+), 10 deletions(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index ac84fcb6c82..bec9ffcf714 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -11,7 +11,7 @@ from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once -from .functional._utils import _get_kernel +from .functional._utils import _get_kernel, is_cvcuda_tensor class Transform(nn.Module): @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) def __init__(self) -> None: super().__init__() diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index bb6051b4e61..765a772fe41 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -15,7 +15,7 @@ from torchvision._utils import sequence_to_str from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, + is_cvcuda_tensor, ), ) } diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 032a993b1f0..52181e4624b 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel # usort: skip +from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index a904d8d7cbd..7ce5bdc7b7e 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,4 +1,5 @@ import io +from typing import TYPE_CHECKING import PIL.Image @@ -8,7 +9,15 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def erase( diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index be254c0d63a..5be9c62902a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import PIL.Image import torch from torch.nn.functional import conv2d @@ -9,7 +11,15 @@ from ._misc import _num_value_bits, to_dtype_image from ._type_conversion import pil_to_tensor, to_pil_image -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 4fcb7fabe0d..c029488001c 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2,7 +2,7 @@ import numbers import warnings from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import PIL.Image import torch @@ -26,7 +26,22 @@ from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format -from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal +from ._utils import ( + _FillTypeJIT, + _get_kernel, + _import_cvcuda, + _is_cvcuda_available, + _register_five_ten_crop_kernel_internal, + _register_kernel_internal, +) + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index daf263df046..0fa05a2113c 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import Optional, TYPE_CHECKING import PIL.Image import torch @@ -13,7 +13,14 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def normalize( diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index ad1eddd258b..73fafaf7425 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -169,3 +169,19 @@ def _is_cvcuda_available(): return True except ImportError: return False + + +def is_cvcuda_tensor(inpt: Any) -> bool: + """ + Check if the input is a CVCUDA tensor. + + Args: + inpt: The input to check. + + Returns: + True if the input is a CV-CUDA tensor, False otherwise. + """ + if _is_cvcuda_available(): + cvcuda = _import_cvcuda() + return isinstance(inpt, cvcuda.Tensor) + return False From e3dd70022fa1c87aca7a9a98068b6e13e802a375 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:26:19 -0800 Subject: [PATCH 02/22] update make_image_cvcuda to have default batch dim --- test/common_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 8c3c9dd58a8..e7bae60c41b 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -400,8 +400,9 @@ def make_image_pil(*args, **kwargs): return to_pil_image(make_image(*args, **kwargs)) -def make_image_cvcuda(*args, **kwargs): - return to_cvcuda_tensor(make_image(*args, **kwargs)) +def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): + # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4) + return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"): From c035df1c6eaebcad25604f8c298a7d9eaf86864b Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 18:16:27 -0800 Subject: [PATCH 03/22] add stanardized setup to main for easier updating of PRs and branches --- test/common_utils.py | 21 ++++++++++++++-- test/test_transforms_v2.py | 2 +- torchvision/transforms/v2/_utils.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 24 +++++++++++++++++-- 4 files changed, 43 insertions(+), 6 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index e7bae60c41b..3b889e93d2e 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -20,13 +20,15 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available from torchvision.utils import _Image_fromarray IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"]) IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" +CVCUDA_AVAILABLE = _is_cvcuda_available() CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" MPS_NOT_AVAILABLE_MSG = "MPS device not available" OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda." @@ -275,6 +277,17 @@ def combinations_grid(**kwargs): return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] +def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor: + tensor = cvcuda_to_tensor(tensor) + if tensor.ndim != 4: + raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.") + if tensor.shape[0] != 1: + raise ValueError( + f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}." + ) + return tensor.squeeze(0).cpu() + + class ImagePair(TensorLikePair): def __init__( self, @@ -287,6 +300,11 @@ def __init__( if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]): actual, expected = (to_image(input) for input in [actual, expected]) + # handle check for CV-CUDA Tensors + if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor): + # Use the PIL compatible tensor, so we can always compare with PIL.Image.Image + actual = cvcuda_to_pil_compatible_tensor(actual) + super().__init__(actual, expected, **other_parameters) self.mae = mae @@ -401,7 +419,6 @@ def make_image_pil(*args, **kwargs): def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): - # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4) return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 670a9d00ffb..7eba65550da 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -21,6 +21,7 @@ import torchvision.transforms.v2 as transforms from common_utils import ( + assert_close, assert_equal, cache, cpu_and_cuda, @@ -41,7 +42,6 @@ ) from torch import nn -from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 765a772fe41..3fc33ce5964 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 6b8f19f12f4..ee562cb2aee 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) +def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]: + # CV-CUDA tensor is always in NHWC layout + # get_dimensions is CHW + return [image.shape[3], image.shape[1], image.shape[2]] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda) + + def get_num_channels(inpt: torch.Tensor) -> int: if torch.jit.is_scripting(): return get_num_channels_image(inpt) @@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels +def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int: + # CV-CUDA tensor is always in NHWC layout + # get_num_channels is C + return image.shape[3] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda) + + def get_size(inpt: torch.Tensor) -> list[int]: if torch.jit.is_scripting(): return get_size_image(inpt) @@ -114,7 +134,7 @@ def _get_size_image_pil(image: PIL.Image.Image) -> list[int]: return [height, width] -def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: +def _get_size_cvcuda(image: "cvcuda.Tensor") -> list[int]: """Get size of `cvcuda.Tensor` with NHWC layout.""" hw = list(image.shape[-3:-1]) ndims = len(hw) @@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: if CVCUDA_AVAILABLE: - _get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda) + _register_kernel_internal(get_size, cvcuda.Tensor)(_get_size_cvcuda) @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False) From 98d7dfb2059eaf2c10c3f549ea45f1d27875134c Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 18:25:09 -0800 Subject: [PATCH 04/22] update is_cvcuda_tensor --- torchvision/transforms/v2/functional/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 73fafaf7425..44b2edeaf2d 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -181,7 +181,8 @@ def is_cvcuda_tensor(inpt: Any) -> bool: Returns: True if the input is a CV-CUDA tensor, False otherwise. """ - if _is_cvcuda_available(): + try: cvcuda = _import_cvcuda() return isinstance(inpt, cvcuda.Tensor) - return False + except ImportError: + return False From ddc116d13febdae1d53507bcde9f103a4c14eba7 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:37:03 -0800 Subject: [PATCH 05/22] add cvcuda to pil compatible to transforms by default --- test/test_transforms_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 7eba65550da..87166477669 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -25,6 +25,7 @@ assert_equal, cache, cpu_and_cuda, + cvcuda_to_pil_compatible_tensor, freeze_rng_state, ignore_jit_no_profile_information_warning, make_bounding_boxes, From e51dc7eabd254261347245f4492892fd0944aae5 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:46:23 -0800 Subject: [PATCH 06/22] remove cvcuda from transform class --- torchvision/transforms/v2/_transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index bec9ffcf714..ac84fcb6c82 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -11,7 +11,7 @@ from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once -from .functional._utils import _get_kernel, is_cvcuda_tensor +from .functional._utils import _get_kernel class Transform(nn.Module): @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) def __init__(self) -> None: super().__init__() From 4939355a2c7421eeba95d7f155fe7953066aec6d Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 11:07:08 -0800 Subject: [PATCH 07/22] resolve more formatting naming --- torchvision/transforms/v2/functional/__init__.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 52181e4624b..032a993b1f0 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip +from ._utils import is_pure_tensor, register_kernel # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index e8630f788ca..af03ad018d4 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,14 +51,14 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) -def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]: +def get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: # CV-CUDA tensor is always in NHWC layout # get_dimensions is CHW return [image.shape[3], image.shape[1], image.shape[2]] if CVCUDA_AVAILABLE: - _register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda) + _register_kernel_internal(get_dimensions, cvcuda.Tensor)(get_dimensions_image_cvcuda) def get_num_channels(inpt: torch.Tensor) -> int: @@ -97,14 +97,14 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels -def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int: +def get_num_channels_image_cvcuda(image: "cvcuda.Tensor") -> int: # CV-CUDA tensor is always in NHWC layout # get_num_channels is C return image.shape[3] if CVCUDA_AVAILABLE: - _register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda) + _register_kernel_internal(get_num_channels, cvcuda.Tensor)(get_num_channels_image_cvcuda) def get_size(inpt: torch.Tensor) -> list[int]: From fbea584365311ae6b56be7e4f6bbff1f834dd31a Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 11:15:49 -0800 Subject: [PATCH 08/22] update is cvcuda tensor impl --- torchvision/transforms/v2/_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 3fc33ce5964..e803aa49c60 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -15,8 +15,8 @@ from torchvision._utils import sequence_to_str from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor -from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor +from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]: @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") @@ -207,7 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, - is_cvcuda_tensor, + _is_cvcuda_tensor, ), ) } From 0afb9cdaa8b007fb992cc0f528cbd92d61d4e20e Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:09:07 -0800 Subject: [PATCH 09/22] stash wip --- .../transforms/v2/functional/_geometry.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 0e27218bc89..2d705bbb0e2 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1535,6 +1535,25 @@ def rotate_video( return rotate_image(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) +def _rotate_cvcuda( + inpt: "cvcuda.Tensor", + angle: float, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[list[float]] = None, + fill: _FillTypeJIT = None, +) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + + + + return cvcuda.rotate(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + + +if _CVCUDA_AVAILABLE: + _register_kernel_internal(rotate, _import_cvcuda().Tensor)(rotate_cvcuda) + + def pad( inpt: torch.Tensor, padding: list[int], From 66913395b487c9decfa22ea0cdeb8c9429b66e64 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:14:49 -0800 Subject: [PATCH 10/22] implement additional cvcuda infra for all branches to avoid duplicate setup --- torchvision/transforms/v2/functional/_geometry.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 2d705bbb0e2..81756ac6797 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1544,9 +1544,7 @@ def _rotate_cvcuda( fill: _FillTypeJIT = None, ) -> "cvcuda.Tensor": cvcuda = _import_cvcuda() - - - + return cvcuda.rotate(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) From 6521570eaafcd00547fcb21953926005e34fd7f0 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:09:07 -0800 Subject: [PATCH 11/22] stash wip --- torchvision/transforms/v2/functional/_geometry.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 81756ac6797..f93ded65457 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1544,7 +1544,6 @@ def _rotate_cvcuda( fill: _FillTypeJIT = None, ) -> "cvcuda.Tensor": cvcuda = _import_cvcuda() - return cvcuda.rotate(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) From 5b451f91a4b21043e474b9def2863aed8c7d175a Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:24:46 -0800 Subject: [PATCH 12/22] wip --- .../transforms/v2/functional/_geometry.py | 35 +++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index f93ded65457..c623b5e0c53 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1535,6 +1535,30 @@ def rotate_video( return rotate_image(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) +if CVCUDA_AVAILABLE: + _cvcuda_interp = { + InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR, + "bilinear": cvcuda.Interp.LINEAR, + "linear": cvcuda.Interp.LINEAR, + 2: cvcuda.Interp.LINEAR, + InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC, + "bicubic": cvcuda.Interp.CUBIC, + 3: cvcuda.Interp.CUBIC, + InterpolationMode.NEAREST: cvcuda.Interp.NEAREST, + "nearest": cvcuda.Interp.NEAREST, + 0: cvcuda.Interp.NEAREST, + InterpolationMode.BOX: cvcuda.Interp.BOX, + "box": cvcuda.Interp.BOX, + 4: cvcuda.Interp.BOX, + InterpolationMode.HAMMING: cvcuda.Interp.HAMMING, + "hamming": cvcuda.Interp.HAMMING, + 5: cvcuda.Interp.HAMMING, + InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS, + "lanczos": cvcuda.Interp.LANCZOS, + 1: cvcuda.Interp.LANCZOS, + } + + def _rotate_cvcuda( inpt: "cvcuda.Tensor", angle: float, @@ -1544,11 +1568,16 @@ def _rotate_cvcuda( fill: _FillTypeJIT = None, ) -> "cvcuda.Tensor": cvcuda = _import_cvcuda() - return cvcuda.rotate(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + interp = _cvcuda_interp.get(interpolation) + if interp is None: + raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA") + + return cvcuda.rotate(inpt, angle_deg=angle, shift=(0.0, 0.0), interpolation=interpolation) -if _CVCUDA_AVAILABLE: - _register_kernel_internal(rotate, _import_cvcuda().Tensor)(rotate_cvcuda) + +if CVCUDA_AVAILABLE: + _register_kernel_internal(rotate, _import_cvcuda().Tensor)(_rotate_cvcuda) def pad( From b8c468ccb074bc75e46ed7364fe2c0185e49f14b Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 10:31:14 -0800 Subject: [PATCH 13/22] rotate passing tests --- test/test_transforms_v2.py | 65 +++++++++++-- .../transforms/v2/functional/_geometry.py | 95 ++++++++++++++++++- 2 files changed, 152 insertions(+), 8 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f767e211125..06bfe14ad2c 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2131,6 +2131,9 @@ def test_kernel_video(self): make_segmentation_mask, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), ], ) def test_functional(self, make_input): @@ -2145,9 +2148,16 @@ def test_functional(self, make_input): (F.rotate_mask, tv_tensors.Mask), (F.rotate_video, tv_tensors.Video), (F.rotate_keypoints, tv_tensors.KeyPoints), + pytest.param( + F._geometry._rotate_cvcuda, + "cvcuda.Tensor", + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if input_type == "cvcuda.Tensor": + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( @@ -2160,6 +2170,9 @@ def test_functional_signature(self, kernel, input_type): make_segmentation_mask, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), ], ) @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -2175,12 +2188,28 @@ def test_transform(self, make_input, device): ) @pytest.mark.parametrize("expand", [False, True]) @pytest.mark.parametrize("fill", CORRECTNESS_FILLS) - def test_functional_image_correctness(self, angle, center, interpolation, expand, fill): - image = make_image(dtype=torch.uint8, device="cpu") + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) + def test_functional_image_correctness(self, angle, center, interpolation, expand, fill, make_input): + image = make_input(dtype=torch.uint8, device="cpu") fill = adapt_fill(fill, dtype=torch.uint8) actual = F.rotate(image, angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill) + + if make_input == make_image_cvcuda: + actual = F.cvcuda_to_tensor(actual).to(device="cpu") + image = F.cvcuda_to_tensor(image) + # drop the batch dimensions + image = image.squeeze(0) + expected = F.to_image( F.rotate( F.to_pil_image(image), angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill @@ -2188,7 +2217,11 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand ) mae = (actual.float() - expected.float()).abs().mean() - assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6 + if make_input == make_image_cvcuda: + # CV-CUDA nearest interpolation differs significantly from PIL, set much higher bound + assert mae < (122.5) if interpolation is transforms.InterpolationMode.NEAREST else 6, f"MAE: {mae}" + else: + assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6, f"MAE: {mae}" @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) @pytest.mark.parametrize( @@ -2197,8 +2230,17 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand @pytest.mark.parametrize("expand", [False, True]) @pytest.mark.parametrize("fill", CORRECTNESS_FILLS) @pytest.mark.parametrize("seed", list(range(5))) - def test_transform_image_correctness(self, center, interpolation, expand, fill, seed): - image = make_image(dtype=torch.uint8, device="cpu") + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) + def test_transform_image_correctness(self, center, interpolation, expand, fill, seed, make_input): + image = make_input(dtype=torch.uint8, device="cpu") fill = adapt_fill(fill, dtype=torch.uint8) @@ -2214,10 +2256,21 @@ def test_transform_image_correctness(self, center, interpolation, expand, fill, actual = transform(image) torch.manual_seed(seed) + + if make_input == make_image_cvcuda: + actual = F.cvcuda_to_tensor(actual).to(device="cpu") + image = F.cvcuda_to_tensor(image) + # drop the batch dimensions + image = image.squeeze(0) + expected = F.to_image(transform(F.to_pil_image(image))) mae = (actual.float() - expected.float()).abs().mean() - assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6 + if make_input == make_image_cvcuda: + # CV-CUDA nearest interpolation differs significantly from PIL, set much higher bound + assert mae < (122.5) if interpolation is transforms.InterpolationMode.NEAREST else 6, f"MAE: {mae}" + else: + assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6, f"MAE: {mae}" def _compute_output_canvas_size(self, *, expand, canvas_size, affine_matrix): if not expand: diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index c623b5e0c53..ea96a3817fe 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1572,8 +1572,99 @@ def _rotate_cvcuda( interp = _cvcuda_interp.get(interpolation) if interp is None: raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA") - - return cvcuda.rotate(inpt, angle_deg=angle, shift=(0.0, 0.0), interpolation=interpolation) + + if center is not None and len(center) != 2: + raise ValueError("Center must be a list of two floats") + + input_height, input_width = inpt.shape[1], inpt.shape[2] + num_channels = inpt.shape[3] + + if fill is None: + fill_value = [0.0] * num_channels + elif isinstance(fill, (int, float)): + fill_value = [float(fill)] * num_channels + else: + fill_value = [float(f) for f in fill] + + # Compute center offset (shift from image center) + # CV-CUDA's shift parameter is the offset from the image center + if center is None: + center_offset = (0.0, 0.0) + else: + center_offset = (center[0] - input_width / 2.0, center[1] - input_height / 2.0) + + if expand: + # Calculate the expanded output size using the same logic as torch + center_f = [0.0, 0.0] + if center is not None: + center_f = [(c - s * 0.5) for c, s in zip(center, [input_width, input_height])] + matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) + output_width, output_height = _compute_affine_output_size(matrix, input_width, input_height) + + # compute padding + pad_left = (output_width - input_width) // 2 + pad_right = output_width - input_width - pad_left + pad_top = (output_height - input_height) // 2 + pad_bottom = output_height - input_height - pad_top + padded = cvcuda.copymakeborder( + inpt, + border_mode=cvcuda.Border.CONSTANT, + border_value=fill_value, + top=pad_top, + bottom=pad_bottom, + left=pad_left, + right=pad_right, + ) + + # get the new center offset + # The center of the original image has moved by (pad_left, pad_top) + new_center_x = (input_width / 2.0 + center_offset[0]) + pad_left + new_center_y = (input_height / 2.0 + center_offset[1]) + pad_top + padded_shift = (new_center_x - output_width / 2.0, new_center_y - output_height / 2.0) + + return cvcuda.rotate(padded, angle_deg=angle, shift=padded_shift, interpolation=interp) + + elif fill is not None and fill_value != [0.0] * num_channels: + # For non-zero fill without expand: + # 1. Pad with fill value to create a larger canvas + # 2. Rotate around the appropriate center + # 3. Crop back to original size + + # compute padding + diag = int(math.ceil(math.sqrt(input_width**2 + input_height**2))) + pad_left = (diag - input_width) // 2 + pad_right = diag - input_width - pad_left + pad_top = (diag - input_height) // 2 + pad_bottom = diag - input_height - pad_top + padded = cvcuda.copymakeborder( + inpt, + border_mode=cvcuda.Border.CONSTANT, + border_value=fill_value, + top=pad_top, + bottom=pad_bottom, + left=pad_left, + right=pad_right, + ) + + # get the new center offset + padded_width, padded_height = padded.shape[2], padded.shape[1] + new_center_x = (input_width / 2.0 + center_offset[0]) + pad_left + new_center_y = (input_height / 2.0 + center_offset[1]) + pad_top + padded_shift = (new_center_x - padded_width / 2.0, new_center_y - padded_height / 2.0) + + # rotate the padded image + rotated = cvcuda.rotate(padded, angle_deg=angle, shift=padded_shift, interpolation=interp) + + # crop back to original size + crop_left = (rotated.shape[2] - input_width) // 2 + crop_top = (rotated.shape[1] - input_height) // 2 + return cvcuda.customcrop( + rotated, + rect=cvcuda.RectI(x=crop_left, y=crop_top, width=input_width, height=input_height), + ) + + else: + return cvcuda.rotate(inpt, angle_deg=angle, shift=center_offset, interpolation=interp) if CVCUDA_AVAILABLE: From 886104257134bae6002039c7609cff1569963341 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:12:11 -0800 Subject: [PATCH 14/22] update rotate to use correct logic --- test/test_transforms_v2.py | 5 +- .../transforms/v2/functional/_geometry.py | 122 +++++++----------- 2 files changed, 47 insertions(+), 80 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 06bfe14ad2c..68f69120a70 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2205,10 +2205,7 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand actual = F.rotate(image, angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill) if make_input == make_image_cvcuda: - actual = F.cvcuda_to_tensor(actual).to(device="cpu") - image = F.cvcuda_to_tensor(image) - # drop the batch dimensions - image = image.squeeze(0) + image = cvcuda_to_pil_compatible_tensor(image) expected = F.to_image( F.rotate( diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index ea96a3817fe..45fe60fe77b 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1569,13 +1569,18 @@ def _rotate_cvcuda( ) -> "cvcuda.Tensor": cvcuda = _import_cvcuda() + angle = angle % 360 + + if angle == 0: + return inpt + + if angle == 180: + return cvcuda.flip(inpt, flipCode=-1) + interp = _cvcuda_interp.get(interpolation) if interp is None: raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA") - if center is not None and len(center) != 2: - raise ValueError("Center must be a list of two floats") - input_height, input_width = inpt.shape[1], inpt.shape[2] num_channels = inpt.shape[3] @@ -1586,85 +1591,50 @@ def _rotate_cvcuda( else: fill_value = [float(f) for f in fill] - # Compute center offset (shift from image center) - # CV-CUDA's shift parameter is the offset from the image center + # Determine the rotation center + # torchvision uses image center by default, cvcuda rotates around upper-left (0,0) + # We need to calculate a shift to effectively rotate around the desired center if center is None: - center_offset = (0.0, 0.0) + cx, cy = input_width / 2.0, input_height / 2.0 else: - center_offset = (center[0] - input_width / 2.0, center[1] - input_height / 2.0) + cx, cy = float(center[0]), float(center[1]) - if expand: - # Calculate the expanded output size using the same logic as torch - center_f = [0.0, 0.0] - if center is not None: - center_f = [(c - s * 0.5) for c, s in zip(center, [input_width, input_height])] - matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) - output_width, output_height = _compute_affine_output_size(matrix, input_width, input_height) - - # compute padding - pad_left = (output_width - input_width) // 2 - pad_right = output_width - input_width - pad_left - pad_top = (output_height - input_height) // 2 - pad_bottom = output_height - input_height - pad_top - padded = cvcuda.copymakeborder( - inpt, - border_mode=cvcuda.Border.CONSTANT, - border_value=fill_value, - top=pad_top, - bottom=pad_bottom, - left=pad_left, - right=pad_right, - ) + angle_rad = math.radians(angle) + cos_angle = math.cos(angle_rad) + sin_angle = math.sin(angle_rad) - # get the new center offset - # The center of the original image has moved by (pad_left, pad_top) - new_center_x = (input_width / 2.0 + center_offset[0]) + pad_left - new_center_y = (input_height / 2.0 + center_offset[1]) + pad_top - padded_shift = (new_center_x - output_width / 2.0, new_center_y - output_height / 2.0) - - return cvcuda.rotate(padded, angle_deg=angle, shift=padded_shift, interpolation=interp) - - elif fill is not None and fill_value != [0.0] * num_channels: - # For non-zero fill without expand: - # 1. Pad with fill value to create a larger canvas - # 2. Rotate around the appropriate center - # 3. Crop back to original size - - # compute padding - diag = int(math.ceil(math.sqrt(input_width**2 + input_height**2))) - pad_left = (diag - input_width) // 2 - pad_right = diag - input_width - pad_left - pad_top = (diag - input_height) // 2 - pad_bottom = diag - input_height - pad_top - padded = cvcuda.copymakeborder( - inpt, - border_mode=cvcuda.Border.CONSTANT, - border_value=fill_value, - top=pad_top, - bottom=pad_bottom, - left=pad_left, - right=pad_right, - ) + # if we are not expanding, simple case + if not expand: + shift_x = (1 - cos_angle) * cx - sin_angle * cy + shift_y = sin_angle * cx + (1 - cos_angle) * cy - # get the new center offset - padded_width, padded_height = padded.shape[2], padded.shape[1] - new_center_x = (input_width / 2.0 + center_offset[0]) + pad_left - new_center_y = (input_height / 2.0 + center_offset[1]) + pad_top - padded_shift = (new_center_x - padded_width / 2.0, new_center_y - padded_height / 2.0) - - # rotate the padded image - rotated = cvcuda.rotate(padded, angle_deg=angle, shift=padded_shift, interpolation=interp) - - # crop back to original size - crop_left = (rotated.shape[2] - input_width) // 2 - crop_top = (rotated.shape[1] - input_height) // 2 - return cvcuda.customcrop( - rotated, - rect=cvcuda.RectI(x=crop_left, y=crop_top, width=input_width, height=input_height), - ) + return cvcuda.rotate(inpt, angle_deg=angle, shift=(shift_x, shift_y), interpolation=interp) - else: - return cvcuda.rotate(inpt, angle_deg=angle, shift=center_offset, interpolation=interp) + # if we need to expand, use much of the same logic as torchvision, for output size/pad + matrix = _get_inverse_affine_matrix([0.0, 0.0], -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) + output_width, output_height = _compute_affine_output_size(matrix, input_width, input_height) + + pad_left = (output_width - input_width) // 2 + pad_right = output_width - input_width - pad_left + pad_top = (output_height - input_height) // 2 + pad_bottom = output_height - input_height - pad_top + + padded = cvcuda.copymakeborder( + inpt, + top=pad_top, + left=pad_left, + bottom=pad_bottom, + right=pad_right, + border_mode=cvcuda.Border.CONSTANT, + border_value=fill_value, + ) + + new_cx = pad_left + cx + new_cy = pad_top + cy + shift_x = (1 - cos_angle) * new_cx - sin_angle * new_cy + shift_y = sin_angle * new_cx + (1 - cos_angle) * new_cy + + return cvcuda.rotate(padded, angle_deg=angle, shift=(shift_x, shift_y), interpolation=interp) if CVCUDA_AVAILABLE: From 550656f6dd1f6518b988e31a29b65aa99ba6e45b Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:24:39 -0800 Subject: [PATCH 15/22] cvcuda rotate verified correct visualizly and passing all tests --- test/test_transforms_v2.py | 15 +++++++-------- torchvision/transforms/v2/functional/_geometry.py | 6 +++++- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 68f69120a70..592d9fca3cb 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2204,7 +2204,8 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand actual = F.rotate(image, angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill) - if make_input == make_image_cvcuda: + if make_input is make_image_cvcuda: + actual = cvcuda_to_pil_compatible_tensor(actual) image = cvcuda_to_pil_compatible_tensor(image) expected = F.to_image( @@ -2214,7 +2215,7 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand ) mae = (actual.float() - expected.float()).abs().mean() - if make_input == make_image_cvcuda: + if make_input is make_image_cvcuda: # CV-CUDA nearest interpolation differs significantly from PIL, set much higher bound assert mae < (122.5) if interpolation is transforms.InterpolationMode.NEAREST else 6, f"MAE: {mae}" else: @@ -2254,16 +2255,14 @@ def test_transform_image_correctness(self, center, interpolation, expand, fill, torch.manual_seed(seed) - if make_input == make_image_cvcuda: - actual = F.cvcuda_to_tensor(actual).to(device="cpu") - image = F.cvcuda_to_tensor(image) - # drop the batch dimensions - image = image.squeeze(0) + if make_input is make_image_cvcuda: + actual = cvcuda_to_pil_compatible_tensor(actual) + image = cvcuda_to_pil_compatible_tensor(image) expected = F.to_image(transform(F.to_pil_image(image))) mae = (actual.float() - expected.float()).abs().mean() - if make_input == make_image_cvcuda: + if make_input is make_image_cvcuda: # CV-CUDA nearest interpolation differs significantly from PIL, set much higher bound assert mae < (122.5) if interpolation is transforms.InterpolationMode.NEAREST else 6, f"MAE: {mae}" else: diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 45fe60fe77b..e0b40b89416 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1596,8 +1596,11 @@ def _rotate_cvcuda( # We need to calculate a shift to effectively rotate around the desired center if center is None: cx, cy = input_width / 2.0, input_height / 2.0 + center_f = [0.0, 0.0] else: cx, cy = float(center[0]), float(center[1]) + # Convert to image-center-relative coordinates (same as torchvision) + center_f = [cx - input_width * 0.5, cy - input_height * 0.5] angle_rad = math.radians(angle) cos_angle = math.cos(angle_rad) @@ -1611,7 +1614,8 @@ def _rotate_cvcuda( return cvcuda.rotate(inpt, angle_deg=angle, shift=(shift_x, shift_y), interpolation=interp) # if we need to expand, use much of the same logic as torchvision, for output size/pad - matrix = _get_inverse_affine_matrix([0.0, 0.0], -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) + # Use center_f (image-center-relative coords) to match torchvision's output size calculation + matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) output_width, output_height = _compute_affine_output_size(matrix, input_width, input_height) pad_left = (output_width - input_width) // 2 From 5fbeac3e6411ebf8e232aa6fc4bbbe9ccc0c04fa Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:45:25 -0800 Subject: [PATCH 16/22] move transformed type check to Rotate transform --- torchvision/transforms/v2/_geometry.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 96166e05e9a..07b0692fafe 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -29,6 +29,7 @@ is_pure_tensor, query_size, ) +from .functional._utils import is_cvcuda_tensor CVCUDA_AVAILABLE = _is_cvcuda_available() @@ -606,6 +607,8 @@ class RandomRotation(Transform): _v1_transform_cls = _transforms.RandomRotation + _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + def __init__( self, degrees: Union[numbers.Number, Sequence], From ea0bdec2f4869e392e5cef5b058faf6218e00a75 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 13:58:07 -0800 Subject: [PATCH 17/22] update rotate to main standards --- test/test_transforms_v2.py | 15 +++++++-------- torchvision/transforms/v2/_geometry.py | 4 ++-- torchvision/transforms/v2/functional/_geometry.py | 4 ++-- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 592d9fca3cb..2475314e74b 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -25,7 +25,6 @@ assert_equal, cache, cpu_and_cuda, - cvcuda_to_pil_compatible_tensor, freeze_rng_state, ignore_jit_no_profile_information_warning, make_bounding_boxes, @@ -2149,14 +2148,14 @@ def test_functional(self, make_input): (F.rotate_video, tv_tensors.Video), (F.rotate_keypoints, tv_tensors.KeyPoints), pytest.param( - F._geometry._rotate_cvcuda, - "cvcuda.Tensor", + F._geometry._rotate_image_cvcuda, + None, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), ), ], ) def test_functional_signature(self, kernel, input_type): - if input_type == "cvcuda.Tensor": + if kernel is F._geometry._rotate_image_cvcuda: input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type) @@ -2205,8 +2204,8 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand actual = F.rotate(image, angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill) if make_input is make_image_cvcuda: - actual = cvcuda_to_pil_compatible_tensor(actual) - image = cvcuda_to_pil_compatible_tensor(image) + actual = F.cvcuda_to_tensor(actual)[0].cpu() + image = F.cvcuda_to_tensor(image)[0].cpu() expected = F.to_image( F.rotate( @@ -2256,8 +2255,8 @@ def test_transform_image_correctness(self, center, interpolation, expand, fill, torch.manual_seed(seed) if make_input is make_image_cvcuda: - actual = cvcuda_to_pil_compatible_tensor(actual) - image = cvcuda_to_pil_compatible_tensor(image) + actual = F.cvcuda_to_tensor(actual)[0].cpu() + image = F.cvcuda_to_tensor(image)[0].cpu() expected = F.to_image(transform(F.to_pil_image(image))) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 07b0692fafe..ea0cb58843f 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -29,7 +29,6 @@ is_pure_tensor, query_size, ) -from .functional._utils import is_cvcuda_tensor CVCUDA_AVAILABLE = _is_cvcuda_available() @@ -607,7 +606,8 @@ class RandomRotation(Transform): _v1_transform_cls = _transforms.RandomRotation - _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + if CVCUDA_AVAILABLE: + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) def __init__( self, diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index e0b40b89416..671f9207938 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1559,7 +1559,7 @@ def rotate_video( } -def _rotate_cvcuda( +def _rotate_image_cvcuda( inpt: "cvcuda.Tensor", angle: float, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, @@ -1642,7 +1642,7 @@ def _rotate_cvcuda( if CVCUDA_AVAILABLE: - _register_kernel_internal(rotate, _import_cvcuda().Tensor)(_rotate_cvcuda) + _register_kernel_internal(rotate, _import_cvcuda().Tensor)(_rotate_image_cvcuda) def pad( From 5aaea08ff017667349048b1a9b1aa2b74fa21ff2 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 13:58:59 -0800 Subject: [PATCH 18/22] remove unneeed cvcuda refs --- torchvision/transforms/v2/functional/_augment.py | 11 +---------- torchvision/transforms/v2/functional/_color.py | 12 +----------- 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 7ce5bdc7b7e..a904d8d7cbd 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,5 +1,4 @@ import io -from typing import TYPE_CHECKING import PIL.Image @@ -9,15 +8,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal - - -CVCUDA_AVAILABLE = _is_cvcuda_available() - -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 +from ._utils import _get_kernel, _register_kernel_internal def erase( diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 5be9c62902a..be254c0d63a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,5 +1,3 @@ -from typing import TYPE_CHECKING - import PIL.Image import torch from torch.nn.functional import conv2d @@ -11,15 +9,7 @@ from ._misc import _num_value_bits, to_dtype_image from ._type_conversion import pil_to_tensor, to_pil_image -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal - - -CVCUDA_AVAILABLE = _is_cvcuda_available() - -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 +from ._utils import _get_kernel, _register_kernel_internal def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: From 1fc4d6dd6135fbf6ab12eeb67bfa259b247fd596 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 14:04:06 -0800 Subject: [PATCH 19/22] refacotr interp into helper --- .../transforms/v2/functional/_geometry.py | 29 +---------- .../transforms/v2/functional/_utils.py | 48 ++++++++++++++++++- 2 files changed, 49 insertions(+), 28 deletions(-) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 671f9207938..9d7ac5cbbd0 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -28,6 +28,7 @@ from ._utils import ( _FillTypeJIT, + _get_cvcuda_interp, _get_kernel, _import_cvcuda, _is_cvcuda_available, @@ -1535,30 +1536,6 @@ def rotate_video( return rotate_image(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) -if CVCUDA_AVAILABLE: - _cvcuda_interp = { - InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR, - "bilinear": cvcuda.Interp.LINEAR, - "linear": cvcuda.Interp.LINEAR, - 2: cvcuda.Interp.LINEAR, - InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC, - "bicubic": cvcuda.Interp.CUBIC, - 3: cvcuda.Interp.CUBIC, - InterpolationMode.NEAREST: cvcuda.Interp.NEAREST, - "nearest": cvcuda.Interp.NEAREST, - 0: cvcuda.Interp.NEAREST, - InterpolationMode.BOX: cvcuda.Interp.BOX, - "box": cvcuda.Interp.BOX, - 4: cvcuda.Interp.BOX, - InterpolationMode.HAMMING: cvcuda.Interp.HAMMING, - "hamming": cvcuda.Interp.HAMMING, - 5: cvcuda.Interp.HAMMING, - InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS, - "lanczos": cvcuda.Interp.LANCZOS, - 1: cvcuda.Interp.LANCZOS, - } - - def _rotate_image_cvcuda( inpt: "cvcuda.Tensor", angle: float, @@ -1577,9 +1554,7 @@ def _rotate_image_cvcuda( if angle == 180: return cvcuda.flip(inpt, flipCode=-1) - interp = _cvcuda_interp.get(interpolation) - if interp is None: - raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA") + interp = _get_cvcuda_interp(interpolation) input_height, input_width = inpt.shape[1], inpt.shape[2] num_channels = inpt.shape[3] diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 11480b30ef9..963f50e08cb 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -1,9 +1,13 @@ import functools from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch from torchvision import tv_tensors +from torchvision.transforms.functional import InterpolationMode + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] _FillType = Union[int, float, Sequence[int], Sequence[float], None] _FillTypeJIT = Optional[list[float]] @@ -177,3 +181,45 @@ def _is_cvcuda_tensor(inpt: Any) -> bool: return isinstance(inpt, cvcuda.Tensor) except ImportError: return False + + +_interpolation_mode_to_cvcuda_interp: dict[InterpolationMode | str | int, "cvcuda.Interp"] = {} + + +def _populate_interpolation_mode_to_cvcuda_interp(): + cvcuda = _import_cvcuda() + + global _interpolation_mode_to_cvcuda_interp + + _interpolation_mode_to_cvcuda_interp = { + InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR, + "bilinear": cvcuda.Interp.LINEAR, + "linear": cvcuda.Interp.LINEAR, + 2: cvcuda.Interp.LINEAR, + InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC, + "bicubic": cvcuda.Interp.CUBIC, + 3: cvcuda.Interp.CUBIC, + InterpolationMode.NEAREST: cvcuda.Interp.NEAREST, + "nearest": cvcuda.Interp.NEAREST, + 0: cvcuda.Interp.NEAREST, + InterpolationMode.BOX: cvcuda.Interp.BOX, + "box": cvcuda.Interp.BOX, + 4: cvcuda.Interp.BOX, + InterpolationMode.HAMMING: cvcuda.Interp.HAMMING, + "hamming": cvcuda.Interp.HAMMING, + 5: cvcuda.Interp.HAMMING, + InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS, + "lanczos": cvcuda.Interp.LANCZOS, + 1: cvcuda.Interp.LANCZOS, + } + + +def _get_cvcuda_interp(interpolation: InterpolationMode | str | int) -> "cvcuda.Interp": + if len(_interpolation_mode_to_cvcuda_interp) == 0: + _populate_interpolation_mode_to_cvcuda_interp() + + interp = _interpolation_mode_to_cvcuda_interp.get(interpolation) + if interp is None: + raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA") + + return interp From ec2a97cf101631ab1e147f3f8b0527bdc9900796 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 14:06:40 -0800 Subject: [PATCH 20/22] drop more unused refs to cvcuda --- torchvision/transforms/v2/functional/_misc.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 0fa05a2113c..daf263df046 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional, TYPE_CHECKING +from typing import Optional import PIL.Image import torch @@ -13,14 +13,7 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor - -CVCUDA_AVAILABLE = _is_cvcuda_available() - -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 +from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor def normalize( From cfc4412eb934558dc265e7ac16881bd80dd873c3 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 17:46:09 -0800 Subject: [PATCH 21/22] update to resize interp func --- torchvision/transforms/v2/functional/_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 963f50e08cb..4111416df79 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -191,6 +191,8 @@ def _populate_interpolation_mode_to_cvcuda_interp(): global _interpolation_mode_to_cvcuda_interp + # CV-CUDA's NEAREST matches PyTorch's 'nearest-exact' (PIL-style) + # not PyTorch's 'nearest' (OpenCV-style). _interpolation_mode_to_cvcuda_interp = { InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR, "bilinear": cvcuda.Interp.LINEAR, @@ -202,6 +204,8 @@ def _populate_interpolation_mode_to_cvcuda_interp(): InterpolationMode.NEAREST: cvcuda.Interp.NEAREST, "nearest": cvcuda.Interp.NEAREST, 0: cvcuda.Interp.NEAREST, + InterpolationMode.NEAREST_EXACT: cvcuda.Interp.NEAREST, + "nearest-exact": cvcuda.Interp.NEAREST, InterpolationMode.BOX: cvcuda.Interp.BOX, "box": cvcuda.Interp.BOX, 4: cvcuda.Interp.BOX, From 1a2d5723a00fbb705ae1221b287d5035669b7982 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 18:13:20 -0800 Subject: [PATCH 22/22] refactor interp setup --- .../transforms/v2/functional/_utils.py | 54 ++++++++----------- 1 file changed, 21 insertions(+), 33 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 4111416df79..a1742ba149f 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -186,41 +186,29 @@ def _is_cvcuda_tensor(inpt: Any) -> bool: _interpolation_mode_to_cvcuda_interp: dict[InterpolationMode | str | int, "cvcuda.Interp"] = {} -def _populate_interpolation_mode_to_cvcuda_interp(): - cvcuda = _import_cvcuda() - - global _interpolation_mode_to_cvcuda_interp - - # CV-CUDA's NEAREST matches PyTorch's 'nearest-exact' (PIL-style) - # not PyTorch's 'nearest' (OpenCV-style). - _interpolation_mode_to_cvcuda_interp = { - InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR, - "bilinear": cvcuda.Interp.LINEAR, - "linear": cvcuda.Interp.LINEAR, - 2: cvcuda.Interp.LINEAR, - InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC, - "bicubic": cvcuda.Interp.CUBIC, - 3: cvcuda.Interp.CUBIC, - InterpolationMode.NEAREST: cvcuda.Interp.NEAREST, - "nearest": cvcuda.Interp.NEAREST, - 0: cvcuda.Interp.NEAREST, - InterpolationMode.NEAREST_EXACT: cvcuda.Interp.NEAREST, - "nearest-exact": cvcuda.Interp.NEAREST, - InterpolationMode.BOX: cvcuda.Interp.BOX, - "box": cvcuda.Interp.BOX, - 4: cvcuda.Interp.BOX, - InterpolationMode.HAMMING: cvcuda.Interp.HAMMING, - "hamming": cvcuda.Interp.HAMMING, - 5: cvcuda.Interp.HAMMING, - InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS, - "lanczos": cvcuda.Interp.LANCZOS, - 1: cvcuda.Interp.LANCZOS, - } - - def _get_cvcuda_interp(interpolation: InterpolationMode | str | int) -> "cvcuda.Interp": if len(_interpolation_mode_to_cvcuda_interp) == 0: - _populate_interpolation_mode_to_cvcuda_interp() + cvcuda = _import_cvcuda() + _interpolation_mode_to_cvcuda_interp[InterpolationMode.NEAREST] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp[InterpolationMode.NEAREST_EXACT] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp[InterpolationMode.BILINEAR] = cvcuda.Interp.LINEAR + _interpolation_mode_to_cvcuda_interp[InterpolationMode.BICUBIC] = cvcuda.Interp.CUBIC + _interpolation_mode_to_cvcuda_interp[InterpolationMode.BOX] = cvcuda.Interp.BOX + _interpolation_mode_to_cvcuda_interp[InterpolationMode.HAMMING] = cvcuda.Interp.HAMMING + _interpolation_mode_to_cvcuda_interp[InterpolationMode.LANCZOS] = cvcuda.Interp.LANCZOS + _interpolation_mode_to_cvcuda_interp["nearest"] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp["nearest-exact"] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp["bilinear"] = cvcuda.Interp.LINEAR + _interpolation_mode_to_cvcuda_interp["bicubic"] = cvcuda.Interp.CUBIC + _interpolation_mode_to_cvcuda_interp["box"] = cvcuda.Interp.BOX + _interpolation_mode_to_cvcuda_interp["hamming"] = cvcuda.Interp.HAMMING + _interpolation_mode_to_cvcuda_interp["lanczos"] = cvcuda.Interp.LANCZOS + _interpolation_mode_to_cvcuda_interp[0] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp[2] = cvcuda.Interp.LINEAR + _interpolation_mode_to_cvcuda_interp[3] = cvcuda.Interp.CUBIC + _interpolation_mode_to_cvcuda_interp[4] = cvcuda.Interp.BOX + _interpolation_mode_to_cvcuda_interp[5] = cvcuda.Interp.HAMMING + _interpolation_mode_to_cvcuda_interp[1] = cvcuda.Interp.LANCZOS interp = _interpolation_mode_to_cvcuda_interp.get(interpolation) if interp is None: