diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 3ce603c3ed2..2459a60b54c 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -21,6 +21,7 @@ import torchvision.transforms.v2 as transforms from common_utils import ( + assert_close, assert_equal, cache, cpu_and_cuda, @@ -41,7 +42,6 @@ ) from torch import nn -from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors @@ -5824,7 +5824,18 @@ def test_kernel_image(self, dtype, device): def test_kernel_video(self): check_kernel(F.invert_video, make_video()) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video]) + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image, + make_image_pil, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], + ) def test_functional(self, make_input): check_functional(F.invert, make_input()) @@ -5835,20 +5846,51 @@ def test_functional(self, make_input): (F._color._invert_image_pil, PIL.Image.Image), (F.invert_image, tv_tensors.Image), (F.invert_video, tv_tensors.Video), + pytest.param( + F._color._invert_image_cvcuda, + None, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if kernel is F._color._invert_image_cvcuda: + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.invert, kernel=kernel, input_type=input_type) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image_pil, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], + ) def test_transform(self, make_input): check_transform(transforms.RandomInvert(p=1), make_input()) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], + ) @pytest.mark.parametrize("fn", [F.invert, transform_cls_to_functional(transforms.RandomInvert, p=1)]) - def test_correctness_image(self, fn): - image = make_image(dtype=torch.uint8, device="cpu") + def test_correctness_image(self, make_input, fn): + image = make_input(dtype=torch.uint8, device="cpu") actual = fn(image) + + if make_input is make_image_cvcuda: + image = F.cvcuda_to_tensor(image)[0].cpu() + expected = F.to_image(F.invert(F.to_pil_image(image))) assert_equal(actual, expected) diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index bf4ae55d232..cad1073d155 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -5,6 +5,7 @@ import torch from torchvision import transforms as _transforms from torchvision.transforms.v2 import functional as F, Transform +from torchvision.transforms.v2.functional._utils import _is_cvcuda_tensor from ._transform import _RandomApplyTransform from ._utils import query_chw @@ -282,6 +283,8 @@ class RandomInvert(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomInvert + _transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,) + def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.invert, inpt) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index be254c0d63a..92e186eab5d 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import PIL.Image import torch from torch.nn.functional import conv2d @@ -9,7 +11,13 @@ from ._misc import _num_value_bits, to_dtype_image from ._type_conversion import pil_to_tensor, to_pil_image -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: @@ -680,6 +688,43 @@ def invert_video(video: torch.Tensor) -> torch.Tensor: return invert_image(video) +_invert_cvcuda_tensors: dict[str, "cvcuda.Tensor"] = {} + + +def _invert_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + + # save the tensors into a dictionary only if CV-CUDA is actually used + # we save these here, since they are static and small in size + if "base" not in _invert_cvcuda_tensors: + _invert_cvcuda_tensors["base"] = cvcuda.as_tensor( + torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device="cuda").reshape(1, 1, 1, 3).contiguous(), "NHWC" + ) + if "scale" not in _invert_cvcuda_tensors: + _invert_cvcuda_tensors["scale"] = cvcuda.as_tensor( + torch.tensor([-1.0, -1.0, -1.0], dtype=torch.float32, device="cuda").reshape(1, 1, 1, 3).contiguous(), + "NHWC", + ) + + base = _invert_cvcuda_tensors["base"] + scale = _invert_cvcuda_tensors["scale"] + + if image.dtype == cvcuda.Type.U8: + shift = 255.0 + elif image.dtype == cvcuda.Type.F32: + shift = 1.0 + else: + raise ValueError(f"Input image dtype must be uint8 or float32, got {image.dtype}") + + # Use normalize to invert: output = (input - base) * scale * global_scale + shift + # For inversion: output = (input - 0) * (-1) * 1 + shift = shift - input + return cvcuda.normalize(image, base=base, scale=scale, globalscale=1.0, globalshift=shift) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(invert, _import_cvcuda().Tensor)(_invert_image_cvcuda) + + def permute_channels(inpt: torch.Tensor, permutation: list[int]) -> torch.Tensor: """Permute the channels of the input according to the given permutation.