From 44db71c0772e5ef5758c38d0e4e8ad9995946c80 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:14:49 -0800 Subject: [PATCH 01/22] implement additional cvcuda infra for all branches to avoid duplicate setup --- torchvision/transforms/v2/_transform.py | 4 ++-- torchvision/transforms/v2/_utils.py | 3 ++- .../transforms/v2/functional/__init__.py | 2 +- .../transforms/v2/functional/_augment.py | 11 ++++++++++- .../transforms/v2/functional/_color.py | 12 +++++++++++- .../transforms/v2/functional/_geometry.py | 19 +++++++++++++++++-- torchvision/transforms/v2/functional/_misc.py | 11 +++++++++-- .../transforms/v2/functional/_utils.py | 16 ++++++++++++++++ 8 files changed, 68 insertions(+), 10 deletions(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index ac84fcb6c82..bec9ffcf714 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -11,7 +11,7 @@ from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once -from .functional._utils import _get_kernel +from .functional._utils import _get_kernel, is_cvcuda_tensor class Transform(nn.Module): @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) def __init__(self) -> None: super().__init__() diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index bb6051b4e61..765a772fe41 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -15,7 +15,7 @@ from torchvision._utils import sequence_to_str from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, + is_cvcuda_tensor, ), ) } diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 032a993b1f0..52181e4624b 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel # usort: skip +from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index a904d8d7cbd..7ce5bdc7b7e 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,4 +1,5 @@ import io +from typing import TYPE_CHECKING import PIL.Image @@ -8,7 +9,15 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from 
torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def erase( diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index be254c0d63a..5be9c62902a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import PIL.Image import torch from torch.nn.functional import conv2d @@ -9,7 +11,15 @@ from ._misc import _num_value_bits, to_dtype_image from ._type_conversion import pil_to_tensor, to_pil_image -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 4fcb7fabe0d..c029488001c 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2,7 +2,7 @@ import numbers import warnings from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import PIL.Image import torch @@ -26,7 +26,22 @@ from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format -from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal +from ._utils import ( + _FillTypeJIT, + _get_kernel, + _import_cvcuda, + _is_cvcuda_available, + _register_five_ten_crop_kernel_internal, + _register_kernel_internal, +) + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index daf263df046..0fa05a2113c 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import Optional, TYPE_CHECKING import PIL.Image import torch @@ -13,7 +13,14 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def normalize( diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index ad1eddd258b..73fafaf7425 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ 
-169,3 +169,19 @@ def _is_cvcuda_available(): return True except ImportError: return False + + +def is_cvcuda_tensor(inpt: Any) -> bool: + """ + Check if the input is a CV-CUDA tensor. + + Args: + inpt: The input to check. + + Returns: + True if the input is a CV-CUDA tensor, False otherwise. + """ + if _is_cvcuda_available(): + cvcuda = _import_cvcuda() + return isinstance(inpt, cvcuda.Tensor) + return False From e3dd70022fa1c87aca7a9a98068b6e13e802a375 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:26:19 -0800 Subject: [PATCH 02/22] update make_image_cvcuda to have default batch dim --- test/common_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 8c3c9dd58a8..e7bae60c41b 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -400,8 +400,9 @@ def make_image_pil(*args, **kwargs): return to_pil_image(make_image(*args, **kwargs)) -def make_image_cvcuda(*args, **kwargs): - return to_cvcuda_tensor(make_image(*args, **kwargs)) +def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): + # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4) + return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"): From c035df1c6eaebcad25604f8c298a7d9eaf86864b Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 18:16:27 -0800 Subject: [PATCH 03/22] add standardized setup to main for easier updating of PRs and branches --- test/common_utils.py | 21 ++++++++++++++-- test/test_transforms_v2.py | 2 +- torchvision/transforms/v2/_utils.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 24 +++++++++++++++++-- 4 files changed, 43 insertions(+), 6 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index e7bae60c41b..3b889e93d2e 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -20,13 +20,15 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available from torchvision.utils import _Image_fromarray IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"]) IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" +CVCUDA_AVAILABLE = _is_cvcuda_available() CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" MPS_NOT_AVAILABLE_MSG = "MPS device not available" OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda." @@ -275,6 +277,17 @@ def combinations_grid(**kwargs): return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] +def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor: + tensor = cvcuda_to_tensor(tensor) + if tensor.ndim != 4: + raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.") + if tensor.shape[0] != 1: + raise ValueError( + f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. 
Got {tensor.shape[0]}." + ) + return tensor.squeeze(0).cpu() + + class ImagePair(TensorLikePair): def __init__( self, @@ -287,6 +300,11 @@ def __init__( if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]): actual, expected = (to_image(input) for input in [actual, expected]) + # handle check for CV-CUDA Tensors + if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor): + # Use the PIL compatible tensor, so we can always compare with PIL.Image.Image + actual = cvcuda_to_pil_compatible_tensor(actual) + super().__init__(actual, expected, **other_parameters) self.mae = mae @@ -401,7 +419,6 @@ def make_image_pil(*args, **kwargs): def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): - # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4) return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 670a9d00ffb..7eba65550da 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -21,6 +21,7 @@ import torchvision.transforms.v2 as transforms from common_utils import ( + assert_close, assert_equal, cache, cpu_and_cuda, @@ -41,7 +42,6 @@ ) from torch import nn -from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 765a772fe41..3fc33ce5964 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 6b8f19f12f4..ee562cb2aee 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) +def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]: + # CV-CUDA tensor is always in NHWC layout + # get_dimensions is CHW + return [image.shape[3], image.shape[1], image.shape[2]] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda) + + def get_num_channels(inpt: torch.Tensor) -> int: if torch.jit.is_scripting(): return get_num_channels_image(inpt) @@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels +def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int: + # CV-CUDA tensor is always in NHWC layout + # get_num_channels is C + return image.shape[3] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda) + + def get_size(inpt: torch.Tensor) -> list[int]: if torch.jit.is_scripting(): return get_size_image(inpt) @@ -114,7 +134,7 @@ def _get_size_image_pil(image: PIL.Image.Image) -> list[int]: return [height, width] -def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: +def 
_get_size_cvcuda(image: "cvcuda.Tensor") -> list[int]: """Get size of `cvcuda.Tensor` with NHWC layout.""" hw = list(image.shape[-3:-1]) ndims = len(hw) @@ -125,7 +145,7 @@ if CVCUDA_AVAILABLE: - _get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda) + _register_kernel_internal(get_size, cvcuda.Tensor)(_get_size_cvcuda) @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False) From 98d7dfb2059eaf2c10c3f549ea45f1d27875134c Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 18:25:09 -0800 Subject: [PATCH 04/22] update is_cvcuda_tensor --- torchvision/transforms/v2/functional/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 73fafaf7425..44b2edeaf2d 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -181,7 +181,8 @@ def is_cvcuda_tensor(inpt: Any) -> bool: Returns: True if the input is a CV-CUDA tensor, False otherwise. """ - if _is_cvcuda_available(): + try: cvcuda = _import_cvcuda() return isinstance(inpt, cvcuda.Tensor) - return False + except ImportError: + return False From ddc116d13febdae1d53507bcde9f103a4c14eba7 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:37:03 -0800 Subject: [PATCH 05/22] add cvcuda_to_pil_compatible_tensor to transforms tests by default --- test/test_transforms_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 7eba65550da..87166477669 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -25,6 +25,7 @@ assert_equal, cache, cpu_and_cuda, + cvcuda_to_pil_compatible_tensor, freeze_rng_state, ignore_jit_no_profile_information_warning, make_bounding_boxes, From e51dc7eabd254261347245f4492892fd0944aae5 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:46:23 -0800 Subject: [PATCH 06/22] remove cvcuda from transform class --- torchvision/transforms/v2/_transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index bec9ffcf714..ac84fcb6c82 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -11,7 +11,7 @@ from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once -from .functional._utils import _get_kernel, is_cvcuda_tensor +from .functional._utils import _get_kernel class Transform(nn.Module): @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) + _transformed_types: tuple[type | Callable[[Any], bool], ...] 
= (torch.Tensor, PIL.Image.Image) def __init__(self) -> None: super().__init__() From 4939355a2c7421eeba95d7f155fe7953066aec6d Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 11:07:08 -0800 Subject: [PATCH 07/22] resolve more formatting and naming issues --- torchvision/transforms/v2/functional/__init__.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 52181e4624b..032a993b1f0 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip +from ._utils import is_pure_tensor, register_kernel # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index e8630f788ca..af03ad018d4 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,14 +51,14 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) -def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]: +def get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: # CV-CUDA tensor is always in NHWC layout # get_dimensions is CHW return [image.shape[3], image.shape[1], image.shape[2]] if CVCUDA_AVAILABLE: - _register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda) + _register_kernel_internal(get_dimensions, cvcuda.Tensor)(get_dimensions_image_cvcuda) def get_num_channels(inpt: torch.Tensor) -> int: @@ -97,14 +97,14 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels -def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int: +def get_num_channels_image_cvcuda(image: "cvcuda.Tensor") -> int: # CV-CUDA tensor is always in NHWC layout # get_num_channels is C return image.shape[3] if CVCUDA_AVAILABLE: - _register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda) + _register_kernel_internal(get_num_channels, cvcuda.Tensor)(get_num_channels_image_cvcuda) def get_size(inpt: torch.Tensor) -> list[int]: From fbea584365311ae6b56be7e4f6bbff1f834dd31a Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 11:15:49 -0800 Subject: [PATCH 08/22] update is_cvcuda_tensor impl --- torchvision/transforms/v2/_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 3fc33ce5964..e803aa49c60 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -15,8 +15,8 @@ from torchvision._utils import sequence_to_str from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor -from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT +from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> 
Sequence[float]: @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") @@ -207,7 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, - is_cvcuda_tensor, + _is_cvcuda_tensor, ), ) } From ffe7a140f28f854aec52e0318738aec220ca8ebd Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 18 Nov 2025 12:14:45 -0800 Subject: [PATCH 09/22] initial cvcuda crop implementation, only minimal tests so far --- test/test_transforms_v2.py | 32 +++++++++++++++++++ .../transforms/v2/functional/__init__.py | 2 ++ .../transforms/v2/functional/_geometry.py | 32 +++++++++++++++++++ 3 files changed, 66 insertions(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f767e211125..ca7aea8a79b 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3765,6 +3765,18 @@ def test_errors(self): transforms.RandomCrop([10, 12], padding=1, padding_mode="abc") +@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="cvcuda not available") +@needs_cuda +class TestCropCVCUDA: + def test_functional(self): + check_functional( + F.crop, make_image_cvcuda(TestCrop.INPUT_SIZE, batch_dims=(1,)), **TestCrop.MINIMAL_CROP_KWARGS + ) + + def test_functional_signature(self): + check_functional_kernel_signature_match(F.crop, kernel=F.crop_cvcuda, input_type=cvcuda.Tensor) + + class TestErase: INPUT_SIZE = (17, 11) FUNCTIONAL_KWARGS = dict( @@ -5045,6 +5057,26 @@ def test_keypoints_correctness(self, output_size, dtype, device, fn): assert_equal(actual, expected) +@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="cvcuda not available") +@needs_cuda +class TestCenterCropCVCUDA: + def test_functional(self): + check_functional( + F.center_crop, + make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,)), + output_size=TestCenterCrop.OUTPUT_SIZES[0], + ) + + def test_functional_signature(self): + check_functional_kernel_signature_match(F.center_crop, kernel=F.center_crop_cvcuda, input_type=cvcuda.Tensor) + + def test_transform(self): + check_transform( + transforms.CenterCrop(TestCenterCrop.OUTPUT_SIZES[0]), + make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,)), + ) + + class TestPerspective: COEFFICIENTS = [ [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018], diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 032a993b1f0..9b437dfd8a8 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -76,12 +76,14 @@ affine_video, center_crop, center_crop_bounding_boxes, + center_crop_cvcuda, center_crop_image, center_crop_keypoints, center_crop_mask, center_crop_video, crop, crop_bounding_boxes, + crop_cvcuda, crop_image, crop_keypoints, crop_mask, diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 0e27218bc89..3476ec71db2 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1924,6 +1924,23 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: 
int, width: int return crop_image(video, top, left, height, width) +def crop_cvcuda( + image: "cvcuda.Tensor", + top: int, + left: int, + height: int, + width: int, +) -> "cvcuda.Tensor": + return cvcuda.customcrop( + image, + cvcuda.RectI(x=left, y=top, width=width, height=height), + ) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(crop, cvcuda.Tensor)(crop_cvcuda) + + def perspective( inpt: torch.Tensor, startpoints: Optional[list[list[int]]], @@ -2674,6 +2691,21 @@ def center_crop_video(video: torch.Tensor, output_size: list[int]) -> torch.Tens return center_crop_image(video, output_size) +def center_crop_cvcuda( + image: "cvcuda.Tensor", + output_size: list[int], +) -> "cvcuda.Tensor": + crop_height, crop_width = _center_crop_parse_output_size(output_size) + return cvcuda.center_crop( + image, + crop_size=(crop_width, crop_height), + ) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(center_crop, cvcuda.Tensor)(center_crop_cvcuda) + + def resized_crop( inpt: torch.Tensor, top: int, From 9133c3daabadd2f5bb1af31776dc4a512a22334b Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 18 Nov 2025 13:10:40 -0800 Subject: [PATCH 10/22] add padding to centercrop and if needed to crop --- .../transforms/v2/functional/_geometry.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 3476ec71db2..ce3f1b476c0 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1931,6 +1931,33 @@ def crop_cvcuda( height: int, width: int, ) -> "cvcuda.Tensor": + image_height, image_width, channels = image.shape[1:] + top_diff = 0 + left_diff = 0 + height_diff = 0 + width_diff = 0 + if top < 0: + top_diff = -1 * top + if left < 0: + left_diff = -1 * left + if top + height > image_height: + height_diff = top + height - image_height + if left + width > image_width: + width_diff = left + width - image_width + if top_diff or left_diff or height_diff or width_diff: + image = cvcuda.copymakeborder( + image, + top=top_diff, + left=left_diff, + bottom=height_diff, + right=width_diff, + border_mode=cvcuda.Border.CONSTANT, + value=[0.0] * channels, + ) + top = 0 + left = 0 + height = image_height + width = image_width return cvcuda.customcrop( image, cvcuda.RectI(x=left, y=top, width=width, height=height), @@ -2696,6 +2723,21 @@ def center_crop_cvcuda( output_size: list[int], ) -> "cvcuda.Tensor": crop_height, crop_width = _center_crop_parse_output_size(output_size) + # we only allow cvcuda conversion for 4 ndim, and always use nhwc layout + image_height = image.shape[1] + image_width = image.shape[2] + channels = image.shape[3] + if crop_height > image_height or crop_width > image_width: + padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) + image = cvcuda.copymakeborder( + image, + top=padding_ltrb[1], + left=padding_ltrb[0], + bottom=padding_ltrb[3], + right=padding_ltrb[2], + border_mode=cvcuda.Border.CONSTANT, + value=[0.0] * channels, + ) return cvcuda.center_crop( image, crop_size=(crop_width, crop_height), From 878d2aed7513a7d19563cb7033fc0d8a77af7e9f Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 18 Nov 2025 13:47:40 -0800 Subject: [PATCH 11/22] test padding for crop_cvcuda, add functional test --- test/test_transforms_v2.py | 7 +++++++ .../transforms/v2/functional/_geometry.py | 18 ++++++++---------- 2 files changed, 15 insertions(+), 10 deletions(-) diff 
--git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index ca7aea8a79b..b57d46fd446 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3776,6 +3776,13 @@ def test_functional(self): def test_functional_signature(self): check_functional_kernel_signature_match(F.crop, kernel=F.crop_cvcuda, input_type=cvcuda.Tensor) + @pytest.mark.parametrize("size", [(10, 5), (25, 15), (25, 5), (10, 15)]) + def test_functional_correctness(self, size): + image = make_image_cvcuda(TestCrop.INPUT_SIZE, batch_dims=(1,)) + actual = F.crop(image, 0, 0, *size) + expected = F.crop(F.cvcuda_to_tensor(image), 0, 0, *size) + assert_equal(F.cvcuda_to_tensor(actual), expected) + class TestErase: INPUT_SIZE = (17, 11) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index ce3f1b476c0..36aa78ca510 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1937,27 +1937,25 @@ def crop_cvcuda( height_diff = 0 width_diff = 0 if top < 0: - top_diff = -1 * top + top_diff = int(-1 * top) if left < 0: - left_diff = -1 * left + left_diff = int(-1 * left) if top + height > image_height: - height_diff = top + height - image_height + height_diff = int(top + height - image_height) if left + width > image_width: - width_diff = left + width - image_width + width_diff = int(left + width - image_width) if top_diff or left_diff or height_diff or width_diff: image = cvcuda.copymakeborder( image, + border_mode=cvcuda.Border.CONSTANT, + border_value=[0.0] * channels, top=top_diff, left=left_diff, bottom=height_diff, right=width_diff, - border_mode=cvcuda.Border.CONSTANT, - value=[0.0] * channels, ) - top = 0 - left = 0 - height = image_height - width = image_width + top = top + top_diff + left = left + left_diff return cvcuda.customcrop( image, cvcuda.RectI(x=left, y=top, width=width, height=height), From 2219ee5bb8255414d51e22840659a2f59fbeb59d Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 18 Nov 2025 15:27:21 -0800 Subject: [PATCH 12/22] center_crop passes functional equiv --- test/test_transforms_v2.py | 7 +++++++ .../transforms/v2/functional/_geometry.py | 17 +++++++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index b57d46fd446..85af8ce9c7b 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5077,6 +5077,13 @@ def test_functional(self): def test_functional_signature(self): check_functional_kernel_signature_match(F.center_crop, kernel=F.center_crop_cvcuda, input_type=cvcuda.Tensor) + @pytest.mark.parametrize("output_size", TestCenterCrop.OUTPUT_SIZES) + def test_functional_correctness(self, output_size): + image = make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,)) + actual = F.center_crop(image, output_size) + expected = F.center_crop(F.cvcuda_to_tensor(image), output_size) + assert_equal(F.cvcuda_to_tensor(actual), expected) + def test_transform(self): check_transform( transforms.CenterCrop(TestCenterCrop.OUTPUT_SIZES[0]), diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 36aa78ca510..4b4da517854 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2729,16 +2729,25 @@ def center_crop_cvcuda( padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) image = 
cvcuda.copymakeborder( image, + border_mode=cvcuda.Border.CONSTANT, + border_value=[0.0] * channels, top=padding_ltrb[1], left=padding_ltrb[0], bottom=padding_ltrb[3], right=padding_ltrb[2], - border_mode=cvcuda.Border.CONSTANT, - value=[0.0] * channels, ) - return cvcuda.center_crop( + + image_height = image.shape[1] + image_width = image.shape[2] + + if crop_width == image_width and crop_height == image_height: + return image + + # use customcrop to match crop_image behavior + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width) + return cvcuda.customcrop( image, - crop_size=(crop_width, crop_height), + cvcuda.RectI(x=crop_left, y=crop_top, width=crop_width, height=crop_height), ) From ed2bd35d572e1bf6c4dbb755f5c3354bfae37aa5 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 20 Nov 2025 13:18:45 -0800 Subject: [PATCH 13/22] fix: crop testing, adhere to conventions --- test/test_transforms_v2.py | 63 ++++++++++++------- torchvision/transforms/v2/_transform.py | 8 ++- .../transforms/v2/functional/__init__.py | 4 +- .../transforms/v2/functional/_geometry.py | 6 +- .../transforms/v2/functional/_utils.py | 1 + 5 files changed, 51 insertions(+), 31 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 85af8ce9c7b..b740f2b5361 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3506,6 +3506,9 @@ def test_kernel_video(self): make_segmentation_mask, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), ], ) def test_functional(self, make_input): @@ -3521,6 +3524,11 @@ def test_functional(self, make_input): (F.crop_mask, tv_tensors.Mask), (F.crop_video, tv_tensors.Video), (F.crop_keypoints, tv_tensors.KeyPoints), + pytest.param( + F._geometry._crop_cvcuda, + _import_cvcuda().Tensor, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), + ), ], ) def test_functional_signature(self, kernel, input_type): @@ -3549,15 +3557,18 @@ def test_functional_image_correctness(self, kwargs): make_segmentation_mask, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), ], ) def test_transform(self, param, value, make_input): - input = make_input(self.INPUT_SIZE) + input_data = make_input(self.INPUT_SIZE) check_sample_input = True if param == "fill": if isinstance(value, (tuple, list)): - if isinstance(input, tv_tensors.Mask): + if isinstance(input_data, tv_tensors.Mask): pytest.skip("F.pad_mask doesn't support non-scalar fill.") else: check_sample_input = False @@ -3566,14 +3577,14 @@ def test_transform(self, param, value, make_input): # 1. size is required # 2. 
the fill parameter only has an effect if we need padding size=[s + 4 for s in self.INPUT_SIZE], - fill=adapt_fill(value, dtype=input.dtype if isinstance(input, torch.Tensor) else torch.uint8), + fill=adapt_fill(value, dtype=input_data.dtype if isinstance(input_data, torch.Tensor) else torch.uint8), ) else: kwargs = {param: value} check_transform( transforms.RandomCrop(**kwargs, pad_if_needed=True), - input, + input_data, check_v1_compatibility=param != "fill" or isinstance(value, (int, float)), check_sample_input=check_sample_input, ) @@ -3637,6 +3648,31 @@ def test_transform_image_correctness(self, param, value, seed): assert_equal(actual, expected) + @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + @pytest.mark.parametrize("size", [(10, 5), (25, 15), (25, 5), (10, 15), (10, 10)]) + @pytest.mark.parametrize("seed", list(range(5))) + def test_transform_cvcuda_correctness(self, size, seed): + pad_if_needed = False + if size[0] > self.INPUT_SIZE[0] or size[1] > self.INPUT_SIZE[1]: + pad_if_needed = True + transform = transforms.RandomCrop(size, pad_if_needed=pad_if_needed) + + image = make_image(size=self.INPUT_SIZE, batch_dims=(1,), device="cuda") + cv_image = F.to_cvcuda_tensor(image) + + with freeze_rng_state(): + torch.manual_seed(seed) + actual = transform(cv_image) + + torch.manual_seed(seed) + expected = transform(image) + + if not pad_if_needed: + torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=0, atol=0) + else: + # if padding is required, CV-CUDA will always fill with zeros + torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=0, atol=get_max_value(image.dtype)) def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width): affine_matrix = np.array( [ @@ -3765,25 +3801,6 @@ def test_errors(self): transforms.RandomCrop([10, 12], padding=1, padding_mode="abc") -@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="cvcuda not available") -@needs_cuda -class TestCropCVCUDA: - def test_functional(self): - check_functional( - F.crop, make_image_cvcuda(TestCrop.INPUT_SIZE, batch_dims=(1,)), **TestCrop.MINIMAL_CROP_KWARGS - ) - - def test_functional_signature(self): - check_functional_kernel_signature_match(F.crop, kernel=F.crop_cvcuda, input_type=cvcuda.Tensor) - - @pytest.mark.parametrize("size", [(10, 5), (25, 15), (25, 5), (10, 15)]) - def test_functional_correctness(self, size): - image = make_image_cvcuda(TestCrop.INPUT_SIZE, batch_dims=(1,)) - actual = F.crop(image, 0, 0, *size) - expected = F.crop(F.cvcuda_to_tensor(image), 0, 0, *size) - assert_equal(F.cvcuda_to_tensor(actual), expected) - - class TestErase: INPUT_SIZE = (17, 11) FUNCTIONAL_KWARGS = dict( diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index ac84fcb6c82..c7b32223b8b 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -11,7 +11,7 @@ from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once -from .functional._utils import _get_kernel +from .functional._utils import _get_kernel, is_cvcuda_tensor class Transform(nn.Module): @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] 
= (torch.Tensor, PIL.Image.Image) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) def __init__(self) -> None: super().__init__() @@ -90,7 +90,9 @@ def _needs_transform_list(self, flat_inputs: list[Any]) -> list[bool]: # However, this case wasn't supported by transforms v1 either, so there is no BC concern. needs_transform_list = [] - transform_pure_tensor = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image) + transform_pure_tensor = not has_any( + flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image, is_cvcuda_tensor + ) for inpt in flat_inputs: needs_transform = True diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 9b437dfd8a8..52181e4624b 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel # usort: skip +from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip from ._meta import ( clamp_bounding_boxes, @@ -76,14 +76,12 @@ affine_video, center_crop, center_crop_bounding_boxes, - center_crop_cvcuda, center_crop_image, center_crop_keypoints, center_crop_mask, center_crop_video, crop, crop_bounding_boxes, - crop_cvcuda, crop_image, crop_keypoints, crop_mask, diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 4b4da517854..b5d02aa1008 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1924,13 +1924,15 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int return crop_image(video, top, left, height, width) -def crop_cvcuda( +def _crop_cvcuda( image: "cvcuda.Tensor", top: int, left: int, height: int, width: int, ) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + image_height, image_width, channels = image.shape[1:] top_diff = 0 left_diff = 0 @@ -1963,7 +1965,7 @@ def crop_cvcuda( if CVCUDA_AVAILABLE: - _register_kernel_internal(crop, cvcuda.Tensor)(crop_cvcuda) + _crop_cvcuda_registered = _register_kernel_internal(crop, _import_cvcuda().Tensor)(_crop_cvcuda) def perspective( diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 11480b30ef9..6a26b59e592 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -5,6 +5,7 @@ import torch from torchvision import tv_tensors + _FillType = Union[int, float, Sequence[int], Sequence[float], None] _FillTypeJIT = Optional[list[float]] From 3582c58a92e2661b050858c93ec0d37fbf7c02c2 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 20 Nov 2025 13:24:30 -0800 Subject: [PATCH 14/22] Fix: update center crop --- test/test_transforms_v2.py | 49 +++++++++---------- .../transforms/v2/functional/_geometry.py | 6 ++- 2 files changed, 26 insertions(+), 29 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index b740f2b5361..c6378976ef4 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -4965,6 +4965,9 @@ def test_kernel_video(self): make_segmentation_mask, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), ], ) def 
test_functional(self, make_input): @@ -4980,6 +4983,11 @@ def test_functional(self, make_input): (F.center_crop_mask, tv_tensors.Mask), (F.center_crop_video, tv_tensors.Video), (F.center_crop_keypoints, tv_tensors.KeyPoints), + pytest.param( + F._geometry._center_crop_cvcuda, + _import_cvcuda().Tensor, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), + ), ], ) def test_functional_signature(self, kernel, input_type): @@ -4995,6 +5003,9 @@ def test_functional_signature(self, kernel, input_type): make_segmentation_mask, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), ], ) def test_transform(self, make_input): @@ -5010,6 +5021,17 @@ def test_image_correctness(self, output_size, fn): assert_equal(actual, expected) + @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + @pytest.mark.parametrize("output_size", OUTPUT_SIZES) + @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)]) + def test_cvcuda_correctness(self, output_size, fn): + image = make_image_cvcuda(self.INPUT_SIZE, dtype=torch.uint8, device="cuda") + + actual = fn(image, output_size) + expected = F.center_crop(F.cvcuda_to_tensor(image), output_size) + + assert_equal(F.cvcuda_to_tensor(actual), expected) + def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size): image_height, image_width = bounding_boxes.canvas_size if isinstance(output_size, int): @@ -5081,33 +5103,6 @@ def test_keypoints_correctness(self, output_size, dtype, device, fn): assert_equal(actual, expected) -@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="cvcuda not available") -@needs_cuda -class TestCenterCropCVCUDA: - def test_functional(self): - check_functional( - F.center_crop, - make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,)), - output_size=TestCenterCrop.OUTPUT_SIZES[0], - ) - - def test_functional_signature(self): - check_functional_kernel_signature_match(F.center_crop, kernel=F.center_crop_cvcuda, input_type=cvcuda.Tensor) - - @pytest.mark.parametrize("output_size", TestCenterCrop.OUTPUT_SIZES) - def test_functional_correctness(self, output_size): - image = make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,)) - actual = F.center_crop(image, output_size) - expected = F.center_crop(F.cvcuda_to_tensor(image), output_size) - assert_equal(F.cvcuda_to_tensor(actual), expected) - - def test_transform(self): - check_transform( - transforms.CenterCrop(TestCenterCrop.OUTPUT_SIZES[0]), - make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,)), - ) - - class TestPerspective: COEFFICIENTS = [ [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018], diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index b5d02aa1008..6b9ceaf93f0 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2718,7 +2718,7 @@ def center_crop_video(video: torch.Tensor, output_size: list[int]) -> torch.Tens return center_crop_image(video, output_size) -def center_crop_cvcuda( +def _center_crop_cvcuda( image: "cvcuda.Tensor", output_size: list[int], ) -> "cvcuda.Tensor": @@ -2754,7 +2754,9 @@ def center_crop_cvcuda( if CVCUDA_AVAILABLE: - _register_kernel_internal(center_crop, cvcuda.Tensor)(center_crop_cvcuda) + _center_crop_cvcuda_registered = _register_kernel_internal(center_crop, _import_cvcuda().Tensor)( + 
_center_crop_cvcuda ) From 18922e3bea4ef343d35228994a8d99013c510409 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 24 Nov 2025 13:42:54 -0800 Subject: [PATCH 15/22] address review comments from other PRs --- test/test_transforms_v2.py | 8 ++++++-- torchvision/transforms/v2/functional/_geometry.py | 6 ++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index c6378976ef4..b324512e006 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3526,12 +3526,14 @@ def test_functional(self, make_input): (F.crop_keypoints, tv_tensors.KeyPoints), pytest.param( F._geometry._crop_cvcuda, - _import_cvcuda().Tensor, + "cvcuda.Tensor", marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), ], ) def test_functional_signature(self, kernel, input_type): + if input_type == "cvcuda.Tensor": + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.crop, kernel=kernel, input_type=input_type) @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS) @@ -4985,12 +4987,14 @@ def test_functional(self, make_input): (F.center_crop_keypoints, tv_tensors.KeyPoints), pytest.param( F._geometry._center_crop_cvcuda, - _import_cvcuda().Tensor, + "cvcuda.Tensor", marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), ], ) def test_functional_signature(self, kernel, input_type): + if input_type == "cvcuda.Tensor": + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.center_crop, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 6b9ceaf93f0..5e4ac3ed372 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1965,7 +1965,7 @@ def _crop_cvcuda( if CVCUDA_AVAILABLE: - _crop_cvcuda_registered = _register_kernel_internal(crop, _import_cvcuda().Tensor)(_crop_cvcuda) + _register_kernel_internal(crop, _import_cvcuda().Tensor)(_crop_cvcuda) def perspective( @@ -2754,9 +2754,7 @@ def _center_crop_cvcuda( if CVCUDA_AVAILABLE: - _center_crop_cvcuda_registered = _register_kernel_internal(center_crop, _import_cvcuda().Tensor)( - _center_crop_cvcuda - ) + _register_kernel_internal(center_crop, _import_cvcuda().Tensor)(_center_crop_cvcuda) def resized_crop( From 37a91e0b010a97c99fae98a7b6688b669ae24c57 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Wed, 26 Nov 2025 10:42:37 -0800 Subject: [PATCH 16/22] simplify and improve crop testing for cvcuda --- test/test_transforms_v2.py | 72 +++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index b324512e006..3a7b738b360 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3537,10 +3537,26 @@ def test_functional_signature(self, kernel, input_type): check_functional_kernel_signature_match(F.crop, kernel=kernel, input_type=input_type) @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], + ) def 
test_functional_image_correctness(self, kwargs, make_input): image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu") actual = F.crop(image, **kwargs) + + if make_input == make_image_cvcuda: + actual = F.cvcuda_to_tensor(actual).to(device="cpu") + actual = actual.squeeze(0) + image = F.cvcuda_to_tensor(image).to(device="cpu") + image = image.squeeze(0) + expected = F.to_image(F.crop(F.to_pil_image(image), **kwargs)) assert_equal(actual, expected) @@ -3628,7 +3644,16 @@ def test_transform_pad_if_needed(self): padding_mode=["constant", "edge", "reflect", "symmetric"], ) @pytest.mark.parametrize("seed", list(range(5))) - def test_transform_image_correctness(self, param, value, seed): + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], + ) + def test_transform_image_correctness(self, param, value, seed, make_input): kwargs = {param: value} if param != "size": # 1. size is required @@ -3639,41 +3664,32 @@ def test_transform_image_correctness(self, param, value, seed): transform = transforms.RandomCrop(pad_if_needed=True, **kwargs) - image = make_image(self.INPUT_SIZE) + will_pad = False + if kwargs["size"][0] > self.INPUT_SIZE[0] or kwargs["size"][1] > self.INPUT_SIZE[1]: + will_pad = True + + image = make_input(self.INPUT_SIZE) with freeze_rng_state(): torch.manual_seed(seed) actual = transform(image) torch.manual_seed(seed) - expected = F.to_image(transform(F.to_pil_image(image))) - assert_equal(actual, expected) + if make_input == make_image_cvcuda: + actual = F.cvcuda_to_tensor(actual).to(device="cpu") + actual = actual.squeeze(0) + image = F.cvcuda_to_tensor(image).to(device="cpu") + image = image.squeeze(0) - @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") - @pytest.mark.parametrize("size", [(10, 5), (25, 15), (25, 5), (10, 15), (10, 10)]) - @pytest.mark.parametrize("seed", list(range(5))) - def test_transform_cvcuda_correctness(self, size, seed): - pad_if_needed = False - if size[0] > self.INPUT_SIZE[0] or size[1] > self.INPUT_SIZE[1]: - pad_if_needed = True - transform = transforms.RandomCrop(size, pad_if_needed=pad_if_needed) - - image = make_image(size=self.INPUT_SIZE, batch_dims=(1,), device="cuda") - cv_image = F.to_cvcuda_tensor(image) - - with freeze_rng_state(): - torch.manual_seed(seed) - actual = transform(cv_image) - - torch.manual_seed(seed) - expected = transform(image) + expected = F.to_image(transform(F.to_pil_image(image))) - if not pad_if_needed: - torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=0, atol=0) + if make_input == make_image_cvcuda and will_pad: + # when padding is applied, CV-CUDA will always fill with zeros + # cannot use assert_equal since it will fail unless the fill is all zeros + torch.testing.assert_close(actual, expected, rtol=0, atol=get_max_value(image.dtype)) else: - # if padding is required, CV-CUDA will always fill with zeros - torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=0, atol=get_max_value(image.dtype)) + assert_equal(actual, expected) def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width): affine_matrix = np.array( From 9b721ef9b37d689c59cc75bcaac7a50bdecb8e94 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Wed, 26 Nov 2025 10:45:24 -0800 Subject: [PATCH 17/22] simplify test for center crop cvcuda --- test/test_transforms_v2.py | 31 ++++++++++++++++++------------- 1 file changed, 18 
insertions(+), 13 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 3a7b738b360..33365dfdf70 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5032,25 +5032,30 @@ def test_transform(self, make_input): check_transform(transforms.CenterCrop(self.OUTPUT_SIZES[0]), make_input(self.INPUT_SIZE)) @pytest.mark.parametrize("output_size", OUTPUT_SIZES) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], + ) @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)]) - def test_image_correctness(self, output_size, fn): - image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu") + def test_image_correctness(self, output_size, make_input, fn): + image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu") actual = fn(image, output_size) - expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size)) - assert_equal(actual, expected) - - @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") - @pytest.mark.parametrize("output_size", OUTPUT_SIZES) - @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)]) - def test_cvcuda_correctness(self, output_size, fn): - image = make_image_cvcuda(self.INPUT_SIZE, dtype=torch.uint8, device="cuda") + if make_input == make_image_cvcuda: + actual = F.cvcuda_to_tensor(actual).to(device="cpu") + actual = actual.squeeze(0) + image = F.cvcuda_to_tensor(image).to(device="cpu") + image = image.squeeze(0) - actual = fn(image, output_size) - expected = F.center_crop(F.cvcuda_to_tensor(image), output_size) + expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size)) - assert_equal(F.cvcuda_to_tensor(actual), expected) + assert_equal(actual, expected) def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size): image_height, image_width = bounding_boxes.canvas_size From 6a0035dfcc01eb82aefc55aad5e06df57622d2c5 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 15:51:41 -0800 Subject: [PATCH 18/22] begin work on finalizing the crop PR to include five and ten crop, adhere to new PR reviews for flip --- test/common_utils.py | 11 +++ test/test_transforms_v2.py | 78 +++++++++++++++---- torchvision/transforms/v2/_geometry.py | 11 +++ torchvision/transforms/v2/_transform.py | 6 +- .../transforms/v2/functional/_geometry.py | 71 +++++++++++++++++ 5 files changed, 160 insertions(+), 17 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index e3fa464b5ea..a841805572d 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -276,6 +276,17 @@ def combinations_grid(**kwargs): return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] +def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor: + tensor = cvcuda_to_tensor(tensor) + if tensor.ndim != 4: + raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.") + if tensor.shape[0] != 1: + raise ValueError( + f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}." 
+ ) + return tensor.squeeze(0).cpu() + + class ImagePair(TensorLikePair): def __init__( self, diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 33365dfdf70..a034fec3e72 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3677,10 +3677,7 @@ def test_transform_image_correctness(self, param, value, seed, make_input): torch.manual_seed(seed) if make_input == make_image_cvcuda: - actual = F.cvcuda_to_tensor(actual).to(device="cpu") - actual = actual.squeeze(0) - image = F.cvcuda_to_tensor(image).to(device="cpu") - image = image.squeeze(0) + image = cvcuda_to_pil_compatible_tensor(image) expected = F.to_image(transform(F.to_pil_image(image))) @@ -5048,10 +5045,7 @@ def test_image_correctness(self, output_size, make_input, fn): actual = fn(image, output_size) if make_input == make_image_cvcuda: - actual = F.cvcuda_to_tensor(actual).to(device="cpu") - actual = actual.squeeze(0) - image = F.cvcuda_to_tensor(image).to(device="cpu") - image = image.squeeze(0) + image = cvcuda_to_pil_compatible_tensor(image) expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size)) @@ -6327,7 +6321,15 @@ def wrapper(*args, **kwargs): @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image_pil, make_image, make_video], + [ + make_image_tensor, + make_image_pil, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], ) @pytest.mark.parametrize("functional", [F.five_crop, F.ten_crop]) def test_functional(self, make_input, functional): @@ -6345,13 +6347,27 @@ def test_functional(self, make_input, functional): (F.five_crop, F._geometry._five_crop_image_pil, PIL.Image.Image), (F.five_crop, F.five_crop_image, tv_tensors.Image), (F.five_crop, F.five_crop_video, tv_tensors.Video), + pytest.param( + F.five_crop, + F._geometry._five_crop_cvcuda, + "cvcuda.Tensor", + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), + ), (F.ten_crop, F.ten_crop_image, torch.Tensor), (F.ten_crop, F._geometry._ten_crop_image_pil, PIL.Image.Image), (F.ten_crop, F.ten_crop_image, tv_tensors.Image), (F.ten_crop, F.ten_crop_video, tv_tensors.Video), + pytest.param( + F.ten_crop, + F._geometry._ten_crop_cvcuda, + "cvcuda.Tensor", + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), + ), ], ) def test_functional_signature(self, functional, kernel, input_type): + if input_type == "cvcuda.Tensor": + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(functional, kernel=kernel, input_type=input_type) class _TransformWrapper(nn.Module): @@ -6373,7 +6389,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image_pil, make_image, make_video], + [ + make_image_tensor, + make_image_pil, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], ) @pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop]) def test_transform(self, make_input, transform_cls): @@ -6391,19 +6415,41 @@ def test_transform_error(self, make_input, transform_cls): with pytest.raises(TypeError, match="not supported"): transform(make_input(self.INPUT_SIZE)) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") 
+ ), + ], + ) @pytest.mark.parametrize("fn", [F.five_crop, transform_cls_to_functional(transforms.FiveCrop)]) - def test_correctness_image_five_crop(self, fn): - image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu") + def test_correctness_image_five_crop(self, make_input, fn): + image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu") actual = fn(image, size=self.OUTPUT_SIZE) + + if make_input is make_image_cvcuda: + image = cvcuda_to_pil_compatible_tensor(image) + expected = F.five_crop(F.to_pil_image(image), size=self.OUTPUT_SIZE) assert isinstance(actual, tuple) assert_equal(actual, [F.to_image(e) for e in expected]) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], + ) @pytest.mark.parametrize("fn_or_class", [F.ten_crop, transforms.TenCrop]) @pytest.mark.parametrize("vertical_flip", [False, True]) - def test_correctness_image_ten_crop(self, fn_or_class, vertical_flip): + def test_correctness_image_ten_crop(self, make_input, fn_or_class, vertical_flip): if fn_or_class is transforms.TenCrop: fn = transform_cls_to_functional(fn_or_class, size=self.OUTPUT_SIZE, vertical_flip=vertical_flip) kwargs = dict() @@ -6411,9 +6457,13 @@ def test_correctness_image_ten_crop(self, fn_or_class, vertical_flip): fn = fn_or_class kwargs = dict(size=self.OUTPUT_SIZE, vertical_flip=vertical_flip) - image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu") + image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu") actual = fn(image, **kwargs) + + if make_input is make_image_cvcuda: + image = cvcuda_to_pil_compatible_tensor(image) + expected = F.ten_crop(F.to_pil_image(image), size=self.OUTPUT_SIZE, vertical_flip=vertical_flip) assert isinstance(actual, tuple) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 96166e05e9a..6eb0214998a 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -26,6 +26,7 @@ get_bounding_boxes, has_all, has_any, + is_cvcuda_tensor, is_pure_tensor, query_size, ) @@ -194,6 +195,8 @@ class CenterCrop(Transform): _v1_transform_cls = _transforms.CenterCrop + _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + def __init__(self, size: Union[int, Sequence[int]]): super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") @@ -360,6 +363,8 @@ class FiveCrop(Transform): _v1_transform_cls = _transforms.FiveCrop + _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + def __init__(self, size: Union[int, Sequence[int]]) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") @@ -404,6 +409,8 @@ class TenCrop(Transform): _v1_transform_cls = _transforms.TenCrop + _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") @@ -811,6 +818,8 @@ class RandomCrop(Transform): _v1_transform_cls = _transforms.RandomCrop + _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + def _extract_params_for_v1_transform(self) -> dict[str, Any]: params = super()._extract_params_for_v1_transform() @@ -1121,6 +1130,8 @@ class 
RandomIoUCrop(Transform): Default, 40. """ + _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + def __init__( self, min_scale: float = 0.3, diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index c7b32223b8b..b0985bb0aec 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -8,10 +8,10 @@ from torch import nn from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision import tv_tensors -from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor +from torchvision.transforms.v2._utils import check_type, has_any, is_cvcuda_tensor, is_pure_tensor from torchvision.utils import _log_api_usage_once -from .functional._utils import _get_kernel, is_cvcuda_tensor +from .functional._utils import _get_kernel class Transform(nn.Module): @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) def __init__(self) -> None: super().__init__() diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 5e4ac3ed372..6bf4fa5940a 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -147,6 +147,14 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor: return horizontal_flip_image(video) +def _horizontal_flip_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": + return _import_cvcuda().flip(image, flipCode=1) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_cvcuda) + + def vertical_flip(inpt: torch.Tensor) -> torch.Tensor: """See :class:`~torchvision.transforms.v2.RandomVerticalFlip` for details.""" if torch.jit.is_scripting(): @@ -243,6 +251,14 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor: return vertical_flip_image(video) +def _vertical_flip_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": + return _import_cvcuda().flip(image, flipCode=0) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_cvcuda) + + # We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are # prevalent and well understood. Thus, we just alias them without deprecating the old names. 
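# A note on the `flipCode` convention used by the two CV-CUDA kernels above:
# it follows OpenCV, where 0 flips around the x-axis (vertical flip) and a
# positive value flips around the y-axis (horizontal flip). A minimal usage
# sketch of the dispatch these registrations enable, assuming CV-CUDA is
# installed; the `to_cvcuda_tensor` call and the layout it accepts are
# assumptions here, not shown in this patch:
#
#     import torch
#     from torchvision.transforms.v2 import functional as F
#
#     img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8, device="cuda")
#     cv_img = F.to_cvcuda_tensor(img)  # cvcuda.Tensor in NHWC layout
#     out = F.horizontal_flip(cv_img)   # routed to _horizontal_flip_cvcuda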
hflip = horizontal_flip @@ -3016,6 +3032,29 @@ def five_crop_video( return five_crop_image(video, size) +def _five_crop_cvcuda( + image: "cvcuda.Tensor", + size: list[int], +) -> tuple["cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor"]: + crop_height, crop_width = _parse_five_crop_size(size) + image_height, image_width = image.shape[-2:] + + if crop_width > image_width or crop_height > image_height: + raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}") + + tl = _crop_cvcuda(image, 0, 0, crop_height, crop_width) + tr = _crop_cvcuda(image, 0, image_width - crop_height, crop_width, crop_height) + bl = _crop_cvcuda(image, image_height - crop_height, 0, crop_width, crop_height) + br = _crop_cvcuda(image, image_height - crop_height, image_width - crop_width, crop_width, crop_height) + center = _center_crop_cvcuda(image, [crop_height, crop_width]) + + return tl, tr, bl, br, center + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(five_crop, _import_cvcuda().Tensor)(_five_crop_cvcuda) + + def ten_crop( inpt: torch.Tensor, size: list[int], vertical_flip: bool = False ) -> tuple[ @@ -3111,3 +3150,35 @@ def ten_crop_video( torch.Tensor, ]: return ten_crop_image(video, size, vertical_flip=vertical_flip) + + +def _ten_crop_cvcuda( + image: "cvcuda.Tensor", + size: list[int], + vertical_flip: bool = False, +) -> tuple[ + "cvcuda.Tensor", + "cvcuda.Tensor", + "cvcuda.Tensor", + "cvcuda.Tensor", + "cvcuda.Tensor", + "cvcuda.Tensor", + "cvcuda.Tensor", + "cvcuda.Tensor", + "cvcuda.Tensor", + "cvcuda.Tensor", +]: + non_flipped = _five_crop_cvcuda(image, size) + + if vertical_flip: + image = _vertical_flip_cvcuda(image) + else: + image = _horizontal_flip_cvcuda(image) + + flipped = _five_crop_cvcuda(image, size) + + return non_flipped + flipped + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(ten_crop, _import_cvcuda().Tensor)(_ten_crop_cvcuda) From e287fc136068caad3c9b5c3783df587c2bb3cfca Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 16:54:47 -0800 Subject: [PATCH 19/22] update to include five ten crop and resized crop, use placeholder transforms for flip and resize for now --- test/test_transforms_v2.py | 49 +++++++++++++---- torchvision/transforms/v2/_geometry.py | 2 + .../transforms/v2/functional/_geometry.py | 52 +++++++++++++++++-- 3 files changed, 88 insertions(+), 15 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index a034fec3e72..5ee1d082bed 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3551,11 +3551,8 @@ def test_functional_image_correctness(self, kwargs, make_input): actual = F.crop(image, **kwargs) - if make_input == make_image_cvcuda: - actual = F.cvcuda_to_tensor(actual).to(device="cpu") - actual = actual.squeeze(0) - image = F.cvcuda_to_tensor(image).to(device="cpu") - image = image.squeeze(0) + if make_input is make_image_cvcuda: + image = cvcuda_to_pil_compatible_tensor(image) expected = F.to_image(F.crop(F.to_pil_image(image), **kwargs)) @@ -3676,7 +3673,7 @@ def test_transform_image_correctness(self, param, value, seed, make_input): torch.manual_seed(seed) - if make_input == make_image_cvcuda: + if make_input is make_image_cvcuda: image = cvcuda_to_pil_compatible_tensor(image) expected = F.to_image(transform(F.to_pil_image(image))) @@ -3684,7 +3681,7 @@ def test_transform_image_correctness(self, param, value, seed, make_input): if make_input == make_image_cvcuda and will_pad: # when padding is applied, CV-CUDA 
will always fill with zeros # cannot use assert_equal since it will fail unless random is all zeros - torch.testing.assert_close(actual, expected, rtol=0, atol=get_max_value(image.dtype)) + assert_close(actual, expected, rtol=0, atol=get_max_value(image.dtype)) else: assert_equal(actual, expected) @@ -4510,6 +4507,9 @@ def test_kernel(self, kernel, make_input): make_segmentation_mask, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), ], ) def test_functional(self, make_input): @@ -4526,9 +4526,16 @@ def test_functional(self, make_input): (F.resized_crop_mask, tv_tensors.Mask), (F.resized_crop_video, tv_tensors.Video), (F.resized_crop_keypoints, tv_tensors.KeyPoints), + pytest.param( + F.resized_crop_image, + "cvcuda.Tensor", + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if input_type == "cvcuda.Tensor": + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.resized_crop, kernel=kernel, input_type=input_type) @param_value_parametrization( @@ -4545,6 +4552,9 @@ def test_functional_signature(self, kernel, input_type): make_segmentation_mask, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), ], ) def test_transform(self, param, value, make_input): @@ -4556,20 +4566,37 @@ def test_transform(self, param, value, make_input): # `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2. # The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT` + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], + ) @pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST}) - def test_functional_image_correctness(self, interpolation): - image = make_image(self.INPUT_SIZE, dtype=torch.uint8) + def test_functional_image_correctness(self, make_input, interpolation): + image = make_input(self.INPUT_SIZE, dtype=torch.uint8) actual = F.resized_crop( image, **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation, antialias=True ) + + if make_input is make_image_cvcuda: + image = cvcuda_to_pil_compatible_tensor(image) + expected = F.to_image( F.resized_crop( F.to_pil_image(image), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation ) ) - torch.testing.assert_close(actual, expected, atol=1, rtol=0) + atol = 1 + if make_input is make_image_cvcuda and interpolation == transforms.InterpolationMode.BICUBIC: + # CV-CUDA BICUBIC differs from PIL ground truth BICUBIC + atol = 10 + assert_close(actual, expected, atol=atol, rtol=0) def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width, size): new_height, new_width = size @@ -5044,7 +5071,7 @@ def test_image_correctness(self, output_size, make_input, fn): actual = fn(image, output_size) - if make_input == make_image_cvcuda: + if make_input is make_image_cvcuda: image = cvcuda_to_pil_compatible_tensor(image) expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size)) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 6eb0214998a..56efb3525e7 100644 --- 
a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -255,6 +258,8 @@ class RandomResizedCrop(Transform): _v1_transform_cls = _transforms.RandomResizedCrop + _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + def __init__( self, size: Union[int, Sequence[int]], diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 6bf4fa5940a..af67298b0f7 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -618,6 +618,32 @@ def resize_video( return resize_image(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) +def _resize_cvcuda( + image: "cvcuda.Tensor", + size: Optional[list[int]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = True, +) -> "cvcuda.Tensor": + # placeholder for now; a native kernel will be handled in the dedicated resize PR + # as a placeholder, convert to/from a torch tensor and reuse resize_image + from ._type_conversion import cvcuda_to_tensor, to_cvcuda_tensor + + return to_cvcuda_tensor( + resize_image( + cvcuda_to_tensor(image), + size=size, + interpolation=interpolation, + max_size=max_size, + antialias=antialias, + ) + ) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(resize, _import_cvcuda().Tensor)(_resize_cvcuda) + + def affine( inpt: torch.Tensor, angle: Union[int, float], @@ -2959,6 +2985,24 @@ def resized_crop_video( ) + +def _resized_crop_cvcuda( + image: "cvcuda.Tensor", + top: int, + left: int, + height: int, + width: int, + size: list[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, +) -> "cvcuda.Tensor": + image = _crop_cvcuda(image, top, left, height, width) + return _resize_cvcuda(image, size, interpolation=interpolation, antialias=antialias) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(resized_crop, _import_cvcuda().Tensor)(_resized_crop_cvcuda) + + def five_crop( inpt: torch.Tensor, size: list[int] ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -3037,15 +3081,15 @@ def _five_crop_cvcuda( size: list[int], ) -> tuple["cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor"]: crop_height, crop_width = _parse_five_crop_size(size) - image_height, image_width = image.shape[-2:] + image_height, image_width = image.shape[1], image.shape[2] if crop_width > image_width or crop_height > image_height: raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}") - tl = _crop_cvcuda(image, 0, 0, crop_height, crop_width) - tr = _crop_cvcuda(image, 0, image_width - crop_height, crop_width, crop_height) - bl = _crop_cvcuda(image, image_height - crop_height, 0, crop_width, crop_height) - br = _crop_cvcuda(image, image_height - crop_height, image_width - crop_width, crop_width, crop_height) + tl = _crop_cvcuda(image, 0, 0, crop_height, crop_width) + tr = _crop_cvcuda(image, 0, image_width - crop_width, crop_height, crop_width) + bl = _crop_cvcuda(image, image_height - crop_height, 0, crop_height, crop_width) + br = _crop_cvcuda(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width) center = _center_crop_cvcuda(image, [crop_height, crop_width]) return tl, tr, bl, br, center From 540551aafc0c14f923c9d965a26ca477fa65cd01 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 11:50:56 -0800 Subject: [PATCH 20/22] 
update crop to new main standards --- test/common_utils.py | 11 ---- test/test_transforms_v2.py | 23 ++++--- torchvision/transforms/v2/_geometry.py | 25 ++++++-- torchvision/transforms/v2/_transform.py | 7 +- .../transforms/v2/functional/__init__.py | 2 +- .../transforms/v2/functional/_geometry.py | 64 ++++++++----------- 6 files changed, 60 insertions(+), 72 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index a841805572d..e3fa464b5ea 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -276,17 +276,6 @@ def combinations_grid(**kwargs): return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] -def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor: - tensor = cvcuda_to_tensor(tensor) - if tensor.ndim != 4: - raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.") - if tensor.shape[0] != 1: - raise ValueError( - f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}." - ) - return tensor.squeeze(0).cpu() - - class ImagePair(TensorLikePair): def __init__( self, diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 5ee1d082bed..e5563b28377 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -25,7 +25,6 @@ assert_equal, cache, cpu_and_cuda, - cvcuda_to_pil_compatible_tensor, freeze_rng_state, ignore_jit_no_profile_information_warning, make_bounding_boxes, @@ -3525,7 +3524,7 @@ def test_functional(self, make_input): (F.crop_video, tv_tensors.Video), (F.crop_keypoints, tv_tensors.KeyPoints), pytest.param( - F._geometry._crop_cvcuda, + F._geometry._crop_image_cvcuda, "cvcuda.Tensor", marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), @@ -3552,7 +3551,7 @@ def test_functional_image_correctness(self, kwargs, make_input): actual = F.crop(image, **kwargs) if make_input is make_image_cvcuda: - image = cvcuda_to_pil_compatible_tensor(image) + image = F.cvcuda_to_tensor(image)[0].cpu() expected = F.to_image(F.crop(F.to_pil_image(image), **kwargs)) @@ -3674,7 +3673,7 @@ def test_transform_image_correctness(self, param, value, seed, make_input): torch.manual_seed(seed) if make_input is make_image_cvcuda: - image = cvcuda_to_pil_compatible_tensor(image) + image = F.cvcuda_to_tensor(image)[0].cpu() expected = F.to_image(transform(F.to_pil_image(image))) @@ -4527,7 +4526,7 @@ def test_functional(self, make_input): (F.resized_crop_video, tv_tensors.Video), (F.resized_crop_keypoints, tv_tensors.KeyPoints), pytest.param( - F.resized_crop_image, + F._geometry._resized_crop_image_cvcuda, "cvcuda.Tensor", marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), @@ -4584,7 +4583,7 @@ def test_functional_image_correctness(self, make_input, interpolation): ) if make_input is make_image_cvcuda: - image = cvcuda_to_pil_compatible_tensor(image) + image = F.cvcuda_to_tensor(image)[0].cpu() expected = F.to_image( F.resized_crop( @@ -5026,7 +5025,7 @@ def test_functional(self, make_input): (F.center_crop_video, tv_tensors.Video), (F.center_crop_keypoints, tv_tensors.KeyPoints), pytest.param( - F._geometry._center_crop_cvcuda, + F._geometry._center_crop_image_cvcuda, "cvcuda.Tensor", marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), @@ -5072,7 +5071,7 @@ def test_image_correctness(self, output_size, make_input, fn): actual = fn(image, output_size) if make_input is make_image_cvcuda: - image = 
cvcuda_to_pil_compatible_tensor(image) + image = F.cvcuda_to_tensor(image)[0].cpu() expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size)) @@ -6376,7 +6375,7 @@ def test_functional(self, make_input, functional): (F.five_crop, F.five_crop_video, tv_tensors.Video), pytest.param( F.five_crop, - F._geometry._five_crop_cvcuda, + F._geometry._five_crop_image_cvcuda, "cvcuda.Tensor", marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), @@ -6386,7 +6385,7 @@ def test_functional(self, make_input, functional): (F.ten_crop, F.ten_crop_video, tv_tensors.Video), pytest.param( F.ten_crop, - F._geometry._ten_crop_cvcuda, + F._geometry._ten_crop_image_cvcuda, "cvcuda.Tensor", marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), @@ -6458,7 +6457,7 @@ def test_correctness_image_five_crop(self, make_input, fn): actual = fn(image, size=self.OUTPUT_SIZE) if make_input is make_image_cvcuda: - image = cvcuda_to_pil_compatible_tensor(image) + image = F.cvcuda_to_tensor(image)[0].cpu() expected = F.five_crop(F.to_pil_image(image), size=self.OUTPUT_SIZE) @@ -6489,7 +6488,7 @@ def test_correctness_image_ten_crop(self, make_input, fn_or_class, vertical_flip actual = fn(image, **kwargs) if make_input is make_image_cvcuda: - image = cvcuda_to_pil_compatible_tensor(image) + image = F.cvcuda_to_tensor(image)[0].cpu() expected = F.ten_crop(F.to_pil_image(image), size=self.OUTPUT_SIZE, vertical_flip=vertical_flip) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 56efb3525e7..6888e6d41f4 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -26,7 +26,6 @@ get_bounding_boxes, has_all, has_any, - is_cvcuda_tensor, is_pure_tensor, query_size, ) @@ -140,6 +139,9 @@ class Resize(Transform): _v1_transform_cls = _transforms.Resize + if CVCUDA_AVAILABLE: + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) + def __init__( self, size: Union[int, Sequence[int], None], @@ -195,7 +197,8 @@ class CenterCrop(Transform): _v1_transform_cls = _transforms.CenterCrop - _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + if CVCUDA_AVAILABLE: + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) def __init__(self, size: Union[int, Sequence[int]]): super().__init__() @@ -255,7 +258,8 @@ class RandomResizedCrop(Transform): _v1_transform_cls = _transforms.RandomResizedCrop - _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + if CVCUDA_AVAILABLE: + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) def __init__( self, @@ -365,7 +369,8 @@ class FiveCrop(Transform): _v1_transform_cls = _transforms.FiveCrop - _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + if CVCUDA_AVAILABLE: + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) def __init__(self, size: Union[int, Sequence[int]]) -> None: super().__init__() @@ -411,7 +416,8 @@ class TenCrop(Transform): _v1_transform_cls = _transforms.TenCrop - _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + if CVCUDA_AVAILABLE: + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None: super().__init__() @@ -820,7 +826,8 @@ class RandomCrop(Transform): _v1_transform_cls = _transforms.RandomCrop - _transformed_types = Transform._transformed_types + 
(is_cvcuda_tensor,) + if CVCUDA_AVAILABLE: + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) def _extract_params_for_v1_transform(self) -> dict[str, Any]: params = super()._extract_params_for_v1_transform() @@ -1132,7 +1139,8 @@ class RandomIoUCrop(Transform): Default, 40. """ - _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + if CVCUDA_AVAILABLE: + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) def __init__( self, @@ -1415,6 +1423,9 @@ class RandomResize(Transform): v0.17, for the PIL and Tensor backends to be consistent. """ + if CVCUDA_AVAILABLE: + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) + def __init__( self, min_size: int, diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index b0985bb0aec..7fb6644032a 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -8,7 +8,7 @@ from torch import nn from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision import tv_tensors -from torchvision.transforms.v2._utils import check_type, has_any, is_cvcuda_tensor, is_pure_tensor +from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once from .functional._utils import _get_kernel @@ -91,7 +91,10 @@ def _needs_transform_list(self, flat_inputs: list[Any]) -> list[bool]: needs_transform_list = [] transform_pure_tensor = not has_any( - flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image, is_cvcuda_tensor + flat_inputs, + tv_tensors.Image, + tv_tensors.Video, + PIL.Image.Image, ) for inpt in flat_inputs: needs_transform = True diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 52181e4624b..032a993b1f0 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip +from ._utils import is_pure_tensor, register_kernel # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index af67298b0f7..a78f1c33f2e 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -147,14 +147,6 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor: return horizontal_flip_image(video) -def _horizontal_flip_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": - return _import_cvcuda().flip(image, flipCode=1) - - -if CVCUDA_AVAILABLE: - _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_cvcuda) - - def vertical_flip(inpt: torch.Tensor) -> torch.Tensor: """See :class:`~torchvision.transforms.v2.RandomVerticalFlip` for details.""" if torch.jit.is_scripting(): @@ -251,14 +243,6 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor: return vertical_flip_image(video) -def _vertical_flip_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": - return _import_cvcuda().flip(image, flipCode=0) - - -if CVCUDA_AVAILABLE: - _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_cvcuda) - - # We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. 
Still, `hflip` and `vflip` are # prevalent and well understood. Thus, we just alias them without deprecating the old names. hflip = horizontal_flip @@ -618,7 +602,7 @@ def resize_video( return resize_image(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) -def _resize_cvcuda( +def _resize_image_cvcuda( image: "cvcuda.Tensor", size: Optional[list[int]], interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, @@ -641,7 +625,7 @@ def _resize_cvcuda( if CVCUDA_AVAILABLE: - _register_kernel_internal(resize, _import_cvcuda().Tensor)(_resize_cvcuda) + _register_kernel_internal(resize, _import_cvcuda().Tensor)(_resize_image_cvcuda) def affine( @@ -1966,7 +1950,7 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int return crop_image(video, top, left, height, width) -def _crop_cvcuda( +def _crop_image_cvcuda( image: "cvcuda.Tensor", top: int, left: int, @@ -2007,7 +1991,7 @@ def _crop_cvcuda( if CVCUDA_AVAILABLE: - _register_kernel_internal(crop, _import_cvcuda().Tensor)(_crop_cvcuda) + _register_kernel_internal(crop, _import_cvcuda().Tensor)(_crop_image_cvcuda) def perspective( @@ -2760,10 +2744,12 @@ def center_crop_video(video: torch.Tensor, output_size: list[int]) -> torch.Tens return center_crop_image(video, output_size) -def _center_crop_cvcuda( +def _center_crop_image_cvcuda( image: "cvcuda.Tensor", output_size: list[int], ) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + crop_height, crop_width = _center_crop_parse_output_size(output_size) # we only allow cvcuda conversion for 4 ndim, and always use nhwc layout image_height = image.shape[1] @@ -2796,7 +2782,7 @@ def _center_crop_cvcuda( if CVCUDA_AVAILABLE: - _register_kernel_internal(center_crop, _import_cvcuda().Tensor)(_center_crop_cvcuda) + _register_kernel_internal(center_crop, _import_cvcuda().Tensor)(_center_crop_image_cvcuda) def resized_crop( @@ -2985,7 +2971,7 @@ def resized_crop_video( ) -def _resized_crop_cvcuda( +def _resized_crop_image_cvcuda( image: "cvcuda.Tensor", top: int, left: int, @@ -2995,12 +2981,12 @@ def _resized_crop_cvcuda( interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, antialias: Optional[bool] = True, ) -> "cvcuda.Tensor": - image = _crop_cvcuda(image, top, left, height, width) - return _resize_cvcuda(image, size, interpolation=interpolation, antialias=antialias) + image = _crop_image_cvcuda(image, top, left, height, width) + return _resize_image_cvcuda(image, size, interpolation=interpolation, antialias=antialias) if CVCUDA_AVAILABLE: - _register_kernel_internal(resized_crop, _import_cvcuda().Tensor)(_resized_crop_cvcuda) + _register_kernel_internal(resized_crop, _import_cvcuda().Tensor)(_resized_crop_image_cvcuda) def five_crop( @@ -3076,7 +3062,7 @@ def five_crop_video( return five_crop_image(video, size) -def _five_crop_cvcuda( +def _five_crop_image_cvcuda( image: "cvcuda.Tensor", size: list[int], ) -> tuple["cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor"]: @@ -3086,17 +3072,17 @@ def _five_crop_cvcuda( if crop_width > image_width or crop_height > image_height: raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}") - tl = _crop_cvcuda(image, 0, 0, crop_height, crop_width) - tr = _crop_cvcuda(image, 0, image_width - crop_width, crop_height, crop_width) - bl = _crop_cvcuda(image, image_height - crop_height, 0, crop_height, crop_width) - br = _crop_cvcuda(image, image_height - crop_height, image_width - 
crop_width, crop_height, crop_width) - center = _center_crop_cvcuda(image, [crop_height, crop_width]) + tl = _crop_image_cvcuda(image, 0, 0, crop_height, crop_width) + tr = _crop_image_cvcuda(image, 0, image_width - crop_width, crop_height, crop_width) + bl = _crop_image_cvcuda(image, image_height - crop_height, 0, crop_height, crop_width) + br = _crop_image_cvcuda(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width) + center = _center_crop_image_cvcuda(image, [crop_height, crop_width]) return tl, tr, bl, br, center if CVCUDA_AVAILABLE: - _register_kernel_internal(five_crop, _import_cvcuda().Tensor)(_five_crop_cvcuda) + _register_kernel_internal(five_crop, _import_cvcuda().Tensor)(_five_crop_image_cvcuda) def ten_crop( @@ -3196,7 +3182,7 @@ def ten_crop_video( return ten_crop_image(video, size, vertical_flip=vertical_flip) -def _ten_crop_cvcuda( +def _ten_crop_image_cvcuda( image: "cvcuda.Tensor", size: list[int], vertical_flip: bool = False, @@ -3212,17 +3198,17 @@ def _ten_crop_cvcuda( "cvcuda.Tensor", "cvcuda.Tensor", ]: - non_flipped = _five_crop_cvcuda(image, size) + non_flipped = _five_crop_image_cvcuda(image, size) if vertical_flip: - image = _vertical_flip_cvcuda(image) + image = _vertical_flip_image_cvcuda(image) else: - image = _horizontal_flip_cvcuda(image) + image = _horizontal_flip_image_cvcuda(image) - flipped = _five_crop_cvcuda(image, size) + flipped = _five_crop_image_cvcuda(image, size) return non_flipped + flipped if CVCUDA_AVAILABLE: - _register_kernel_internal(ten_crop, _import_cvcuda().Tensor)(_ten_crop_cvcuda) + _register_kernel_internal(ten_crop, _import_cvcuda().Tensor)(_ten_crop_image_cvcuda) From 62877ca10399078659f62fd26e817ebb711d1b8f Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 11:53:25 -0800 Subject: [PATCH 21/22] reduce diff --- torchvision/transforms/v2/_transform.py | 7 +------ torchvision/transforms/v2/functional/_augment.py | 11 +---------- torchvision/transforms/v2/functional/_color.py | 12 +----------- torchvision/transforms/v2/functional/_misc.py | 11 ++--------- torchvision/transforms/v2/functional/_utils.py | 1 - 5 files changed, 5 insertions(+), 37 deletions(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index 7fb6644032a..ac84fcb6c82 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -90,12 +90,7 @@ def _needs_transform_list(self, flat_inputs: list[Any]) -> list[bool]: # However, this case wasn't supported by transforms v1 either, so there is no BC concern. 
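# The pure-tensor heuristic restored below: a plain torch.Tensor is treated
# as an image only when no tv_tensors.Image, tv_tensors.Video, or
# PIL.Image.Image is present in the flattened inputs, and even then only the
# first pure tensor gets transformed while the rest pass through untouched.
# A small sketch of that observable behavior:
#
#     import torch
#     from torchvision.transforms import v2
#
#     t = v2.RandomHorizontalFlip(p=1.0)
#     img = torch.rand(3, 8, 8)  # first pure tensor: flipped
#     aux = torch.rand(10)       # later pure tensor: passed through unchanged
#     out_img, out_aux = t(img, aux)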
needs_transform_list = [] - transform_pure_tensor = not has_any( - flat_inputs, - tv_tensors.Image, - tv_tensors.Video, - PIL.Image.Image, - ) + transform_pure_tensor = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image) for inpt in flat_inputs: needs_transform = True diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 7ce5bdc7b7e..a904d8d7cbd 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,5 +1,4 @@ import io -from typing import TYPE_CHECKING import PIL.Image @@ -9,15 +8,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal - - -CVCUDA_AVAILABLE = _is_cvcuda_available() - -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 +from ._utils import _get_kernel, _register_kernel_internal def erase( diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 5be9c62902a..be254c0d63a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,5 +1,3 @@ -from typing import TYPE_CHECKING - import PIL.Image import torch from torch.nn.functional import conv2d @@ -11,15 +9,7 @@ from ._misc import _num_value_bits, to_dtype_image from ._type_conversion import pil_to_tensor, to_pil_image -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal - - -CVCUDA_AVAILABLE = _is_cvcuda_available() - -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 +from ._utils import _get_kernel, _register_kernel_internal def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 0fa05a2113c..daf263df046 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional, TYPE_CHECKING +from typing import Optional import PIL.Image import torch @@ -13,14 +13,7 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor - -CVCUDA_AVAILABLE = _is_cvcuda_available() - -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 +from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor def normalize( diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 6a26b59e592..11480b30ef9 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -5,7 +5,6 @@ import torch from torchvision import tv_tensors - _FillType = Union[int, float, Sequence[int], Sequence[float], None] _FillTypeJIT = Optional[list[float]] From c2964f86fe59b03ad5352a4534aa6e75cb72450c Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 13:46:26 -0800 Subject: [PATCH 22/22] check input type on kernel for signature test --- test/test_transforms_v2.py | 18 +++++++++--------- 
1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index e5563b28377..e7b003f2014 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3525,13 +3525,13 @@ def test_functional(self, make_input): (F.crop_keypoints, tv_tensors.KeyPoints), pytest.param( F._geometry._crop_image_cvcuda, - "cvcuda.Tensor", + None, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), ], ) def test_functional_signature(self, kernel, input_type): - if input_type == "cvcuda.Tensor": + if kernel is F._geometry._crop_image_cvcuda: input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.crop, kernel=kernel, input_type=input_type) @@ -4527,13 +4527,13 @@ def test_functional(self, make_input): (F.resized_crop_keypoints, tv_tensors.KeyPoints), pytest.param( F._geometry._resized_crop_image_cvcuda, - "cvcuda.Tensor", + None, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), ], ) def test_functional_signature(self, kernel, input_type): - if input_type == "cvcuda.Tensor": + if kernel is F._geometry._resized_crop_image_cvcuda: input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.resized_crop, kernel=kernel, input_type=input_type) @@ -5026,13 +5026,13 @@ def test_functional(self, make_input): (F.center_crop_keypoints, tv_tensors.KeyPoints), pytest.param( F._geometry._center_crop_image_cvcuda, - "cvcuda.Tensor", + None, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), ], ) def test_functional_signature(self, kernel, input_type): - if input_type == "cvcuda.Tensor": + if kernel is F._geometry._center_crop_image_cvcuda: input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.center_crop, kernel=kernel, input_type=input_type) @@ -6376,7 +6376,7 @@ def test_functional(self, make_input, functional): pytest.param( F.five_crop, F._geometry._five_crop_image_cvcuda, - "cvcuda.Tensor", + None, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), (F.ten_crop, F.ten_crop_image, torch.Tensor), @@ -6386,13 +6386,13 @@ def test_functional(self, make_input, functional): pytest.param( F.ten_crop, F._geometry._ten_crop_image_cvcuda, - "cvcuda.Tensor", + None, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), ], ) def test_functional_signature(self, functional, kernel, input_type): - if input_type == "cvcuda.Tensor": + if kernel is F._geometry._five_crop_image_cvcuda or kernel is F._geometry._ten_crop_image_cvcuda: input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(functional, kernel=kernel, input_type=input_type)
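
The `None` sentinel in this final patch keeps pytest collection import-safe on machines without CV-CUDA: the parametrize list never references `cvcuda` at collection time, and the concrete `Tensor` type is resolved via a kernel-identity check only inside the skip-guarded test body. Below is a condensed, self-contained sketch of that pattern; every name in it (`plain_kernel`, `cvcuda_kernel`, `test_signature`) is an illustrative stand-in rather than torchvision's own.

    import importlib.util

    import pytest
    import torch

    CVCUDA_AVAILABLE = importlib.util.find_spec("cvcuda") is not None


    def _import_cvcuda():
        # Deferred import: only executed inside tests that are not skipped.
        import cvcuda

        return cvcuda


    def plain_kernel(t):
        return t


    def cvcuda_kernel(t):
        return t


    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (plain_kernel, torch.Tensor),
            pytest.param(
                cvcuda_kernel,
                None,  # placeholder: cvcuda must not be touched during collection
                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
            ),
        ],
    )
    def test_signature(kernel, input_type):
        if kernel is cvcuda_kernel:  # resolve the real type only when CV-CUDA exists
            input_type = _import_cvcuda().Tensor
        assert isinstance(input_type, type)

Checking kernel identity rather than `input_type is None` mirrors the patch subject ("check input type on kernel for signature test") and keeps the sentinel free to take other meanings later.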