From 98616f4fd1ecd971580412487451c7244cc77a21 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Mon, 24 Nov 2025 15:18:19 -0800 Subject: [PATCH 1/8] Update CVCUDA tests for horizontal and vertical flip and make changes according to the comments --- test/test_transforms_v2.py | 103 +++++++++++++++--- torchvision/transforms/v2/_geometry.py | 12 +- .../transforms/v2/functional/_geometry.py | 24 +++- 3 files changed, 119 insertions(+), 20 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 670a9d00ffb..d7684d7c9a2 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1240,6 +1240,10 @@ def test_kernel_video(self): make_image_tensor, make_image_pil, make_image, + pytest.param( + functools.partial(make_image_cvcuda, batch_dims=(1,)), + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1255,6 +1259,11 @@ def test_functional(self, make_input): (F.horizontal_flip_image, torch.Tensor), (F._geometry._horizontal_flip_image_pil, PIL.Image.Image), (F.horizontal_flip_image, tv_tensors.Image), + pytest.param( + F._geometry._horizontal_flip_image_cvcuda, + cvcuda.Tensor, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), (F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes), (F.horizontal_flip_mask, tv_tensors.Mask), (F.horizontal_flip_video, tv_tensors.Video), @@ -1270,6 +1279,10 @@ def test_functional_signature(self, kernel, input_type): make_image_tensor, make_image_pil, make_image, + pytest.param( + functools.partial(make_image_cvcuda, batch_dims=(1,)), + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1283,13 +1296,32 @@ def test_transform(self, make_input, device): @pytest.mark.parametrize( "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)] ) - def test_image_correctness(self, fn): - image = make_image(dtype=torch.uint8, device="cpu") - actual = fn(image) - expected = F.to_image(F.horizontal_flip(F.to_pil_image(image))) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + functools.partial(make_image_cvcuda, batch_dims=(1,)), + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), + ], + ) - torch.testing.assert_close(actual, expected) + def test_image_correctness(self, fn, make_input): + image = make_input() + actual = fn(image) + if isinstance(image, cvcuda.Tensor): + # For CVCUDA input + expected = F.horizontal_flip(F.cvcuda_to_tensor(image)) + print("actual is ", F.cvcuda_to_tensor(actual)) + print("expected is ", expected) + assert_equal(F.cvcuda_to_tensor(actual), expected) + + else: + # For PIL/regular image input + expected = F.to_image(F.horizontal_flip(F.to_pil_image(image))) + assert_equal(actual, expected) def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes): affine_matrix = np.array( @@ -1345,6 +1377,10 @@ def test_keypoints_correctness(self, fn): make_image_tensor, make_image_pil, make_image, + pytest.param( + functools.partial(make_image_cvcuda, batch_dims=(1,)), + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1354,12 +1390,13 @@ def test_keypoints_correctness(self, fn): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_transform_noop(self, make_input, device): input = make_input(device=device) - transform = transforms.RandomHorizontalFlip(p=0) - output = transform(input) + if isinstance(input, cvcuda.Tensor): + assert_equal(F.cvcuda_to_tensor(output), F.cvcuda_to_tensor(input)) + else: + assert_equal(output, input) - assert_equal(output, input) class TestAffine: @@ -1856,6 +1893,10 @@ def test_kernel_video(self): make_image_tensor, make_image_pil, make_image, + pytest.param( + functools.partial(make_image_cvcuda, batch_dims=(1,)), + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1871,6 +1912,11 @@ def test_functional(self, make_input): (F.vertical_flip_image, torch.Tensor), (F._geometry._vertical_flip_image_pil, PIL.Image.Image), (F.vertical_flip_image, tv_tensors.Image), + pytest.param( + F._geometry._vertical_flip_image_cvcuda, + cvcuda.Tensor, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), (F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes), (F.vertical_flip_mask, tv_tensors.Mask), (F.vertical_flip_video, tv_tensors.Video), @@ -1886,6 +1932,10 @@ def test_functional_signature(self, kernel, input_type): make_image_tensor, make_image_pil, make_image, + pytest.param( + functools.partial(make_image_cvcuda, batch_dims=(1,)), + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1897,13 +1947,28 @@ def test_transform(self, make_input, device): check_transform(transforms.RandomVerticalFlip(p=1), make_input(device=device)) @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)]) - def test_image_correctness(self, fn): - image = make_image(dtype=torch.uint8, device="cpu") + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + functools.partial(make_image_cvcuda, batch_dims=(1,)), + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), + ], + ) + def test_image_correctness(self, fn, make_input): + image = make_input() actual = fn(image) - expected = F.to_image(F.vertical_flip(F.to_pil_image(image))) - - torch.testing.assert_close(actual, expected) + if isinstance(image, cvcuda.Tensor): + # For CVCUDA input + expected = F.vertical_flip(F.cvcuda_to_tensor(image)) + assert_equal(F.cvcuda_to_tensor(actual), expected) + else: + # For PIL/regular image input + expected = F.to_image(F.vertical_flip(F.to_pil_image(image))) + assert_equal(actual, expected) def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes): affine_matrix = np.array( @@ -1955,6 +2020,10 @@ def test_keypoints_correctness(self, fn): make_image_tensor, make_image_pil, make_image, + pytest.param( + functools.partial(make_image_cvcuda, batch_dims=(1,)), + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1964,12 +2033,12 @@ def test_keypoints_correctness(self, fn): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_transform_noop(self, make_input, device): input = make_input(device=device) - transform = transforms.RandomVerticalFlip(p=0) - output = transform(input) - - assert_equal(output, input) + if isinstance(input, cvcuda.Tensor): + assert_equal(F.cvcuda_to_tensor(output), F.cvcuda_to_tensor(input)) + else: + assert_equal(output, input) class TestRotate: diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 1418a6b4953..bef6894de1b 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -11,7 +11,7 @@ from torchvision.ops.boxes import box_iou from torchvision.transforms.functional import _get_perspective_coeffs from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform -from torchvision.transforms.v2.functional._utils import _FillType +from torchvision.transforms.v2.functional._utils import _FillType, _import_cvcuda, _is_cvcuda_available from ._transform import _RandomApplyTransform from ._utils import ( @@ -30,6 +30,9 @@ query_size, ) +CVCUDA_AVAILABLE = _is_cvcuda_available() +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() class RandomHorizontalFlip(_RandomApplyTransform): """Horizontally flip the input with a given probability. @@ -45,6 +48,9 @@ class RandomHorizontalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomHorizontalFlip + if CVCUDA_AVAILABLE: + _transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor) + def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.horizontal_flip, inpt) @@ -63,6 +69,10 @@ class RandomVerticalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomVerticalFlip + if CVCUDA_AVAILABLE: + _transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor) + + def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.vertical_flip, inpt) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 4fcb7fabe0d..82ac3a95bb5 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2,7 +2,7 @@ import numbers import warnings from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import PIL.Image import torch @@ -26,7 +26,13 @@ from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format -from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal +from ._utils import _FillTypeJIT, _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_five_ten_crop_kernel_internal, _register_kernel_internal + +CVCUDA_AVAILABLE = _is_cvcuda_available() +if TYPE_CHECKING: + import cvcuda +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -61,6 +67,12 @@ def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor: def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: return _FP.hflip(image) +def _horizontal_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": + return _import_cvcuda().flip(image, flipCode=1) + + +if CVCUDA_AVAILABLE: + _horizontal_flip_image_cvcuda_registered = _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_image_cvcuda) @_register_kernel_internal(horizontal_flip, tv_tensors.Mask) def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: @@ -150,6 +162,14 @@ def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: return _FP.vflip(image) +def _vertical_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": + return _import_cvcuda().flip(image, flipCode=0) + + +if CVCUDA_AVAILABLE: + _vertical_flip_image_cvcuda_registered = _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_image_cvcuda) + + @_register_kernel_internal(vertical_flip, tv_tensors.Mask) def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: return vertical_flip_image(mask) From 42fcc4114a1c1838591ddb3e1a3121084775d450 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Wed, 26 Nov 2025 02:17:14 -0800 Subject: [PATCH 2/8] WIP: cvcuda flip transforms - pending tech lead review --- test/test_transforms_v2.py | 6 ------ torchvision/transforms/v2/_geometry.py | 8 +------- torchvision/transforms/v2/_transform.py | 6 +++++- .../transforms/v2/functional/_geometry.py | 19 ++++++++++++++++--- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index d7684d7c9a2..f15faafebf8 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1312,14 +1312,10 @@ def test_image_correctness(self, fn, make_input): image = make_input() actual = fn(image) if isinstance(image, cvcuda.Tensor): - # For CVCUDA input expected = F.horizontal_flip(F.cvcuda_to_tensor(image)) - print("actual is ", F.cvcuda_to_tensor(actual)) - print("expected is ", expected) assert_equal(F.cvcuda_to_tensor(actual), expected) else: - # For PIL/regular image input expected = F.to_image(F.horizontal_flip(F.to_pil_image(image))) assert_equal(actual, expected) @@ -1962,11 +1958,9 @@ def test_image_correctness(self, fn, make_input): image = make_input() actual = fn(image) if isinstance(image, cvcuda.Tensor): - # For CVCUDA input expected = F.vertical_flip(F.cvcuda_to_tensor(image)) assert_equal(F.cvcuda_to_tensor(actual), expected) else: - # For PIL/regular image input expected = F.to_image(F.vertical_flip(F.to_pil_image(image))) assert_equal(actual, expected) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index bef6894de1b..cbf3fae6982 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -34,6 +34,7 @@ if CVCUDA_AVAILABLE: cvcuda = _import_cvcuda() + class RandomHorizontalFlip(_RandomApplyTransform): """Horizontally flip the input with a given probability. @@ -48,9 +49,6 @@ class RandomHorizontalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomHorizontalFlip - if CVCUDA_AVAILABLE: - _transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor) - def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.horizontal_flip, inpt) @@ -69,10 +67,6 @@ class RandomVerticalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomVerticalFlip - if CVCUDA_AVAILABLE: - _transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor) - - def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.vertical_flip, inpt) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index ac84fcb6c82..28297e9e4f2 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -12,7 +12,8 @@ from torchvision.utils import _log_api_usage_once from .functional._utils import _get_kernel - +from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available +CVCUDA_AVAILABLE = _is_cvcuda_available() class Transform(nn.Module): """Base class to implement your own v2 transforms. @@ -24,6 +25,9 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) + if CVCUDA_AVAILABLE: + _transformed_types += (_import_cvcuda().Tensor,) + def __init__(self) -> None: super().__init__() diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 82ac3a95bb5..d0e76cdc358 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -26,7 +26,14 @@ from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format -from ._utils import _FillTypeJIT, _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_five_ten_crop_kernel_internal, _register_kernel_internal +from ._utils import ( + _FillTypeJIT, + _get_kernel, + _import_cvcuda, + _is_cvcuda_available, + _register_five_ten_crop_kernel_internal, + _register_kernel_internal, +) CVCUDA_AVAILABLE = _is_cvcuda_available() if TYPE_CHECKING: @@ -67,12 +74,16 @@ def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor: def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: return _FP.hflip(image) + def _horizontal_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": return _import_cvcuda().flip(image, flipCode=1) if CVCUDA_AVAILABLE: - _horizontal_flip_image_cvcuda_registered = _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_image_cvcuda) + _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)( + _horizontal_flip_image_cvcuda + ) + @_register_kernel_internal(horizontal_flip, tv_tensors.Mask) def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: @@ -167,7 +178,9 @@ def _vertical_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": if CVCUDA_AVAILABLE: - _vertical_flip_image_cvcuda_registered = _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_image_cvcuda) + _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)( + _vertical_flip_image_cvcuda + ) @_register_kernel_internal(vertical_flip, tv_tensors.Mask) From 9423b4d97cb4b28407bd70272d41cb87f350cf88 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Thu, 27 Nov 2025 03:31:47 -0800 Subject: [PATCH 3/8] Address review comments from Nov 26th --- test/common_utils.py | 42 ++++++++++++++++++---- test/test_transforms_v2.py | 47 +++++++++++-------------- torchvision/transforms/v2/_geometry.py | 6 ++++ torchvision/transforms/v2/_transform.py | 6 ++-- 4 files changed, 64 insertions(+), 37 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 8c3c9dd58a8..6bd585d394d 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -188,7 +188,12 @@ def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None): def _assert_approx_equal_tensor_to_pil( - tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", allowed_percentage_diff=None + tensor, + pil_image, + tol=1e-5, + msg=None, + agg_method="mean", + allowed_percentage_diff=None, ): # FIXME: this is handled automatically by `assert_close` below. Let's remove this in favor of it # TODO: we could just merge this into _assert_equal_tensor_to_pil @@ -284,8 +289,29 @@ def __init__( mae=False, **other_parameters, ): - if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]): - actual, expected = (to_image(input) for input in [actual, expected]) + # Convert PIL images to tv_tensors.Image (regardless of what the other is) + if isinstance(actual, PIL.Image.Image): + actual = to_image(actual) + if isinstance(expected, PIL.Image.Image): + expected = to_image(expected) + + # Convert CV-CUDA tensors to torch.Tensor (regardless of what the other is) + try: + import cvcuda + from torchvision.transforms.v2.functional import cvcuda_to_tensor + + if isinstance(actual, cvcuda.Tensor): + actual = cvcuda_to_tensor(actual) + # Remove batch dimension if it's 1 for easier comparison + if actual.shape[0] == 1: + actual = actual[0] + if isinstance(expected, cvcuda.Tensor): + expected = cvcuda_to_tensor(expected) + # Remove batch dimension if it's 1 for easier comparison + if expected.shape[0] == 1: + expected = expected[0] + except ImportError: + pass super().__init__(actual, expected, **other_parameters) self.mae = mae @@ -400,8 +426,8 @@ def make_image_pil(*args, **kwargs): return to_pil_image(make_image(*args, **kwargs)) -def make_image_cvcuda(*args, **kwargs): - return to_cvcuda_tensor(make_image(*args, **kwargs)) +def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): + return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"): @@ -541,5 +567,9 @@ def ignore_jit_no_profile_information_warning(): # with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore # them. with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message=re.escape("operator() profile_node %"), category=UserWarning) + warnings.filterwarnings( + "ignore", + message=re.escape("operator() profile_node %"), + category=UserWarning, + ) yield diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f15faafebf8..1a21c08013a 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1241,7 +1241,7 @@ def test_kernel_video(self): make_image_pil, make_image, pytest.param( - functools.partial(make_image_cvcuda, batch_dims=(1,)), + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), ), make_bounding_boxes, @@ -1280,7 +1280,7 @@ def test_functional_signature(self, kernel, input_type): make_image_pil, make_image, pytest.param( - functools.partial(make_image_cvcuda, batch_dims=(1,)), + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), ), make_bounding_boxes, @@ -1296,28 +1296,24 @@ def test_transform(self, make_input, device): @pytest.mark.parametrize( "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)] ) - @pytest.mark.parametrize( "make_input", [ make_image, pytest.param( - functools.partial(make_image_cvcuda, batch_dims=(1,)), + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), ), ], ) - def test_image_correctness(self, fn, make_input): image = make_input() actual = fn(image) - if isinstance(image, cvcuda.Tensor): - expected = F.horizontal_flip(F.cvcuda_to_tensor(image)) - assert_equal(F.cvcuda_to_tensor(actual), expected) - - else: - expected = F.to_image(F.horizontal_flip(F.to_pil_image(image))) - assert_equal(actual, expected) + if make_input is make_image_cvcuda: + image = F.cvcuda_to_tensor(image)[0] # Remove batch dimension: [1, C, H, W] -> [C, H, W] + expected = F.horizontal_flip(F.to_pil_image(image)) + # CV-CUDA tensors are on CUDA, PIL images are on CPU, so disable device checking + assert_equal(actual, expected, check_device=False) def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes): affine_matrix = np.array( @@ -1374,7 +1370,7 @@ def test_keypoints_correctness(self, fn): make_image_pil, make_image, pytest.param( - functools.partial(make_image_cvcuda, batch_dims=(1,)), + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), ), make_bounding_boxes, @@ -1394,7 +1390,6 @@ def test_transform_noop(self, make_input, device): assert_equal(output, input) - class TestAffine: _EXHAUSTIVE_TYPE_AFFINE_KWARGS = dict( # float, int @@ -1890,7 +1885,7 @@ def test_kernel_video(self): make_image_pil, make_image, pytest.param( - functools.partial(make_image_cvcuda, batch_dims=(1,)), + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), ), make_bounding_boxes, @@ -1929,7 +1924,7 @@ def test_functional_signature(self, kernel, input_type): make_image_pil, make_image, pytest.param( - functools.partial(make_image_cvcuda, batch_dims=(1,)), + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), ), make_bounding_boxes, @@ -1948,21 +1943,19 @@ def test_transform(self, make_input, device): [ make_image, pytest.param( - functools.partial(make_image_cvcuda, batch_dims=(1,)), + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), ), ], ) - def test_image_correctness(self, fn, make_input): image = make_input() actual = fn(image) - if isinstance(image, cvcuda.Tensor): - expected = F.vertical_flip(F.cvcuda_to_tensor(image)) - assert_equal(F.cvcuda_to_tensor(actual), expected) - else: - expected = F.to_image(F.vertical_flip(F.to_pil_image(image))) - assert_equal(actual, expected) + if make_input is make_image_cvcuda: + image = F.cvcuda_to_tensor(image)[0] # Remove batch dimension: [1, C, H, W] -> [C, H, W] + expected = F.vertical_flip(F.to_pil_image(image)) + # CV-CUDA tensors are on CUDA, PIL images are on CPU, so disable device checking + assert_equal(actual, expected, check_device=False) def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes): affine_matrix = np.array( @@ -2015,7 +2008,7 @@ def test_keypoints_correctness(self, fn): make_image_pil, make_image, pytest.param( - functools.partial(make_image_cvcuda, batch_dims=(1,)), + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), ), make_bounding_boxes, @@ -7164,7 +7157,7 @@ def test_classification_preset(image_type, label_type, dataset_return_type, to_t out = t(sample) - assert type(out) == type(sample) + assert type(out) is type(sample) if dataset_return_type is tuple: out_image, out_label = out @@ -7475,7 +7468,7 @@ def test_functional(self, input_type): boxes, valid = F.sanitize_bounding_boxes(boxes, format=format, canvas_size=canvas_size, min_size=min_size) assert_equal(valid, torch.tensor(expected_valid_mask)) - assert type(valid) == torch.Tensor + assert type(valid) is torch.Tensor assert boxes.shape[0] == sum(valid) assert isinstance(boxes, input_type) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index cbf3fae6982..7bb17aa7f41 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -49,6 +49,9 @@ class RandomHorizontalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomHorizontalFlip + if CVCUDA_AVAILABLE: + _transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor) + def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.horizontal_flip, inpt) @@ -67,6 +70,9 @@ class RandomVerticalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomVerticalFlip + if CVCUDA_AVAILABLE: + _transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor) + def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.vertical_flip, inpt) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index 28297e9e4f2..610e7d7e83b 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -12,9 +12,10 @@ from torchvision.utils import _log_api_usage_once from .functional._utils import _get_kernel -from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available + CVCUDA_AVAILABLE = _is_cvcuda_available() + class Transform(nn.Module): """Base class to implement your own v2 transforms. @@ -25,9 +26,6 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) - if CVCUDA_AVAILABLE: - _transformed_types += (_import_cvcuda().Tensor,) - def __init__(self) -> None: super().__init__() From 1e5e9ede063a415b0b48084d02e68191dd3ac25e Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 2 Dec 2025 12:56:08 -0800 Subject: [PATCH 4/8] address the comments iteration one --- test/common_utils.py | 22 +++++++---------- test/test_transforms_v2.py | 24 +++++++------------ torchvision/transforms/v2/_geometry.py | 11 ++++++--- torchvision/transforms/v2/_transform.py | 2 -- .../transforms/v2/functional/_geometry.py | 10 ++------ .../transforms/v2/functional/_utils.py | 8 +++++++ 6 files changed, 34 insertions(+), 43 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 6bd585d394d..7439727a00e 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -20,7 +20,8 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available from torchvision.utils import _Image_fromarray @@ -188,12 +189,7 @@ def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None): def _assert_approx_equal_tensor_to_pil( - tensor, - pil_image, - tol=1e-5, - msg=None, - agg_method="mean", - allowed_percentage_diff=None, + tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", allowed_percentage_diff=None ): # FIXME: this is handled automatically by `assert_close` below. Let's remove this in favor of it # TODO: we could just merge this into _assert_equal_tensor_to_pil @@ -295,23 +291,21 @@ def __init__( if isinstance(expected, PIL.Image.Image): expected = to_image(expected) - # Convert CV-CUDA tensors to torch.Tensor (regardless of what the other is) - try: - import cvcuda - from torchvision.transforms.v2.functional import cvcuda_to_tensor + if _is_cvcuda_available(): + cvcuda = _import_cvcuda() if isinstance(actual, cvcuda.Tensor): - actual = cvcuda_to_tensor(actual) + actual = cvcuda_to_tensor(actual) # No import needed here anymore! # Remove batch dimension if it's 1 for easier comparison if actual.shape[0] == 1: actual = actual[0] + actual = actual.cpu() if isinstance(expected, cvcuda.Tensor): expected = cvcuda_to_tensor(expected) # Remove batch dimension if it's 1 for easier comparison if expected.shape[0] == 1: expected = expected[0] - except ImportError: - pass + expected = expected.cpu() super().__init__(actual, expected, **other_parameters) self.mae = mae diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 1a21c08013a..8c6e0b6da82 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1310,10 +1310,9 @@ def test_image_correctness(self, fn, make_input): image = make_input() actual = fn(image) if make_input is make_image_cvcuda: - image = F.cvcuda_to_tensor(image)[0] # Remove batch dimension: [1, C, H, W] -> [C, H, W] + image = F.cvcuda_to_tensor(image)[0].cpu() expected = F.horizontal_flip(F.to_pil_image(image)) - # CV-CUDA tensors are on CUDA, PIL images are on CPU, so disable device checking - assert_equal(actual, expected, check_device=False) + assert_equal(actual, expected) def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes): affine_matrix = np.array( @@ -1384,10 +1383,7 @@ def test_transform_noop(self, make_input, device): input = make_input(device=device) transform = transforms.RandomHorizontalFlip(p=0) output = transform(input) - if isinstance(input, cvcuda.Tensor): - assert_equal(F.cvcuda_to_tensor(output), F.cvcuda_to_tensor(input)) - else: - assert_equal(output, input) + assert_equal(output, input) class TestAffine: @@ -1952,10 +1948,9 @@ def test_image_correctness(self, fn, make_input): image = make_input() actual = fn(image) if make_input is make_image_cvcuda: - image = F.cvcuda_to_tensor(image)[0] # Remove batch dimension: [1, C, H, W] -> [C, H, W] + image = F.cvcuda_to_tensor(image)[0].cpu() expected = F.vertical_flip(F.to_pil_image(image)) - # CV-CUDA tensors are on CUDA, PIL images are on CPU, so disable device checking - assert_equal(actual, expected, check_device=False) + assert_equal(actual, expected) def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes): affine_matrix = np.array( @@ -2022,10 +2017,7 @@ def test_transform_noop(self, make_input, device): input = make_input(device=device) transform = transforms.RandomVerticalFlip(p=0) output = transform(input) - if isinstance(input, cvcuda.Tensor): - assert_equal(F.cvcuda_to_tensor(output), F.cvcuda_to_tensor(input)) - else: - assert_equal(output, input) + assert_equal(output, input) class TestRotate: @@ -7157,7 +7149,7 @@ def test_classification_preset(image_type, label_type, dataset_return_type, to_t out = t(sample) - assert type(out) is type(sample) + assert type(out) == type(sample) if dataset_return_type is tuple: out_image, out_label = out @@ -7468,7 +7460,7 @@ def test_functional(self, input_type): boxes, valid = F.sanitize_bounding_boxes(boxes, format=format, canvas_size=canvas_size, min_size=min_size) assert_equal(valid, torch.tensor(expected_valid_mask)) - assert type(valid) is torch.Tensor + assert type(valid) == torch.Tensor assert boxes.shape[0] == sum(valid) assert isinstance(boxes, input_type) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 7bb17aa7f41..55e9804b1eb 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -11,7 +11,12 @@ from torchvision.ops.boxes import box_iou from torchvision.transforms.functional import _get_perspective_coeffs from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform -from torchvision.transforms.v2.functional._utils import _FillType, _import_cvcuda, _is_cvcuda_available +from torchvision.transforms.v2.functional._utils import ( + _FillType, + _import_cvcuda, + _is_cvcuda_available, + is_cvcuda_tensor, +) from ._transform import _RandomApplyTransform from ._utils import ( @@ -50,7 +55,7 @@ class RandomHorizontalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomHorizontalFlip if CVCUDA_AVAILABLE: - _transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor) + _transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.horizontal_flip, inpt) @@ -71,7 +76,7 @@ class RandomVerticalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomVerticalFlip if CVCUDA_AVAILABLE: - _transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor) + _transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.vertical_flip, inpt) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index 610e7d7e83b..ac84fcb6c82 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -13,8 +13,6 @@ from .functional._utils import _get_kernel -CVCUDA_AVAILABLE = _is_cvcuda_available() - class Transform(nn.Module): """Base class to implement your own v2 transforms. diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index d0e76cdc358..9a2930e7ece 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -38,8 +38,6 @@ CVCUDA_AVAILABLE = _is_cvcuda_available() if TYPE_CHECKING: import cvcuda -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -80,9 +78,7 @@ def _horizontal_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": if CVCUDA_AVAILABLE: - _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)( - _horizontal_flip_image_cvcuda - ) + _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_image_cvcuda) @_register_kernel_internal(horizontal_flip, tv_tensors.Mask) @@ -178,9 +174,7 @@ def _vertical_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": if CVCUDA_AVAILABLE: - _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)( - _vertical_flip_image_cvcuda - ) + _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_image_cvcuda) @_register_kernel_internal(vertical_flip, tv_tensors.Mask) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index ad1eddd258b..b6697bb2f97 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -169,3 +169,11 @@ def _is_cvcuda_available(): return True except ImportError: return False + + +def is_cvcuda_tensor(inpt: Any) -> bool: + try: + cvcuda = _import_cvcuda() + return isinstance(inpt, cvcuda.Tensor) + except ImportError: + return False From f8279c10d18fb5d05bff936eb927008db670cf10 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 2 Dec 2025 17:26:25 -0800 Subject: [PATCH 5/8] Add type ignore comment for the cvcuda import in functional/_geometry.py --- torchvision/transforms/v2/functional/_geometry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 9a2930e7ece..0e27218bc89 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -37,7 +37,7 @@ CVCUDA_AVAILABLE = _is_cvcuda_available() if TYPE_CHECKING: - import cvcuda + import cvcuda # type: ignore[import-not-found] def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: From ab89bb5a2c9f16eaa0eb94860d33fd8306d9a72e Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Wed, 3 Dec 2025 12:54:56 -0800 Subject: [PATCH 6/8] address the comments iteration two --- test/common_utils.py | 5 ++--- test/test_transforms_v2.py | 8 ++++++-- torchvision/transforms/v2/_geometry.py | 6 +++--- torchvision/transforms/v2/functional/_utils.py | 2 +- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 7439727a00e..1bd1b3bd522 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -295,14 +295,13 @@ def __init__( cvcuda = _import_cvcuda() if isinstance(actual, cvcuda.Tensor): - actual = cvcuda_to_tensor(actual) # No import needed here anymore! - # Remove batch dimension if it's 1 for easier comparison + actual = cvcuda_to_tensor(actual) + # Remove batch dimension if it's 1 for easier comparison against 3D PIL images if actual.shape[0] == 1: actual = actual[0] actual = actual.cpu() if isinstance(expected, cvcuda.Tensor): expected = cvcuda_to_tensor(expected) - # Remove batch dimension if it's 1 for easier comparison if expected.shape[0] == 1: expected = expected[0] expected = expected.cpu() diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 8c6e0b6da82..e184df89341 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1261,7 +1261,7 @@ def test_functional(self, make_input): (F.horizontal_flip_image, tv_tensors.Image), pytest.param( F._geometry._horizontal_flip_image_cvcuda, - cvcuda.Tensor, + None, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), ), (F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes), @@ -1271,6 +1271,8 @@ def test_functional(self, make_input): ], ) def test_functional_signature(self, kernel, input_type): + if kernel is F._geometry._horizontal_flip_image_cvcuda: + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( @@ -1901,7 +1903,7 @@ def test_functional(self, make_input): (F.vertical_flip_image, tv_tensors.Image), pytest.param( F._geometry._vertical_flip_image_cvcuda, - cvcuda.Tensor, + None, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), ), (F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes), @@ -1911,6 +1913,8 @@ def test_functional(self, make_input): ], ) def test_functional_signature(self, kernel, input_type): + if kernel is F._geometry._vertical_flip_image_cvcuda: + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 55e9804b1eb..afa951decb4 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -15,7 +15,7 @@ _FillType, _import_cvcuda, _is_cvcuda_available, - is_cvcuda_tensor, + _is_cvcuda_tensor, ) from ._transform import _RandomApplyTransform @@ -55,7 +55,7 @@ class RandomHorizontalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomHorizontalFlip if CVCUDA_AVAILABLE: - _transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,) + _transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.horizontal_flip, inpt) @@ -76,7 +76,7 @@ class RandomVerticalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomVerticalFlip if CVCUDA_AVAILABLE: - _transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,) + _transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.vertical_flip, inpt) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index b6697bb2f97..11480b30ef9 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -171,7 +171,7 @@ def _is_cvcuda_available(): return False -def is_cvcuda_tensor(inpt: Any) -> bool: +def _is_cvcuda_tensor(inpt: Any) -> bool: try: cvcuda = _import_cvcuda() return isinstance(inpt, cvcuda.Tensor) From 39a1dbadc3bdfd71ba43dbafffc4a55042abb442 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Wed, 3 Dec 2025 15:54:14 -0800 Subject: [PATCH 7/8] use _is_cvcuda_tensor instead of if isinstance(actual, cvcuda.Tensor) --- test/common_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 1bd1b3bd522..e3fa464b5ea 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -21,7 +21,7 @@ from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image -from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available +from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor from torchvision.utils import _Image_fromarray @@ -292,15 +292,13 @@ def __init__( expected = to_image(expected) if _is_cvcuda_available(): - cvcuda = _import_cvcuda() - - if isinstance(actual, cvcuda.Tensor): + if _is_cvcuda_tensor(actual): actual = cvcuda_to_tensor(actual) # Remove batch dimension if it's 1 for easier comparison against 3D PIL images if actual.shape[0] == 1: actual = actual[0] actual = actual.cpu() - if isinstance(expected, cvcuda.Tensor): + if _is_cvcuda_tensor(expected): expected = cvcuda_to_tensor(expected) if expected.shape[0] == 1: expected = expected[0] From a9554c70425e4129cb424af7896d6733fd38d7b7 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Wed, 3 Dec 2025 16:38:26 -0800 Subject: [PATCH 8/8] remove unnecessary import --- torchvision/transforms/v2/_geometry.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index afa951decb4..96166e05e9a 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -11,12 +11,7 @@ from torchvision.ops.boxes import box_iou from torchvision.transforms.functional import _get_perspective_coeffs from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform -from torchvision.transforms.v2.functional._utils import ( - _FillType, - _import_cvcuda, - _is_cvcuda_available, - _is_cvcuda_tensor, -) +from torchvision.transforms.v2.functional._utils import _FillType, _is_cvcuda_available, _is_cvcuda_tensor from ._transform import _RandomApplyTransform from ._utils import ( @@ -36,8 +31,6 @@ ) CVCUDA_AVAILABLE = _is_cvcuda_available() -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() class RandomHorizontalFlip(_RandomApplyTransform):