diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 3ce603c3ed2..b8523a56fa2 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -21,6 +21,7 @@
 import torchvision.transforms.v2 as transforms
 
 from common_utils import (
+    assert_close,
     assert_equal,
     cache,
     cpu_and_cuda,
@@ -41,7 +42,6 @@
 )
 
 from torch import nn
-from torch.testing import assert_close
 from torch.utils._pytree import tree_flatten, tree_map
 from torch.utils.data import DataLoader, default_collate
 from torchvision import tv_tensors
@@ -6409,7 +6409,17 @@ class TestRgbToGrayscale:
     def test_kernel_image(self, dtype, device):
         check_kernel(F.rgb_to_grayscale_image, make_image(dtype=dtype, device=device))
 
-    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     def test_functional(self, make_input):
         check_functional(F.rgb_to_grayscale, make_input())
 
@@ -6419,23 +6429,53 @@ def test_functional(self, make_input):
             (F.rgb_to_grayscale_image, torch.Tensor),
             (F._color._rgb_to_grayscale_image_pil, PIL.Image.Image),
             (F.rgb_to_grayscale_image, tv_tensors.Image),
+            pytest.param(
+                F._color._rgb_to_grayscale_image_cvcuda,
+                None,
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if kernel is F._color._rgb_to_grayscale_image_cvcuda:
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.rgb_to_grayscale, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize("transform", [transforms.Grayscale(), transforms.RandomGrayscale(p=1)])
-    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     def test_transform(self, transform, make_input):
         check_transform(transform, make_input())
 
     @pytest.mark.parametrize("num_output_channels", [1, 3])
     @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("fn", [F.rgb_to_grayscale, transform_cls_to_functional(transforms.Grayscale)])
-    def test_image_correctness(self, num_output_channels, color_space, fn):
-        image = make_image(dtype=torch.uint8, device="cpu", color_space=color_space)
+    def test_image_correctness(self, num_output_channels, color_space, make_input, fn):
+        image = make_input(dtype=torch.uint8, device="cpu", color_space=color_space)
 
         actual = fn(image, num_output_channels=num_output_channels)
+
+        if make_input is make_image_cvcuda:
+            image = F.cvcuda_to_tensor(image)[0].cpu()
+
         expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels))
 
         assert_equal(actual, expected, rtol=0, atol=1)
@@ -6473,7 +6513,17 @@ class TestGrayscaleToRgb:
     def test_kernel_image(self, dtype, device):
         check_kernel(F.grayscale_to_rgb_image, make_image(dtype=dtype, device=device))
 
-    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     def test_functional(self, make_input):
         check_functional(F.grayscale_to_rgb, make_input())
 
@@ -6483,20 +6533,50 @@ def test_functional(self, make_input):
             (F.rgb_to_grayscale_image, torch.Tensor),
             (F._color._rgb_to_grayscale_image_pil, PIL.Image.Image),
             (F.rgb_to_grayscale_image, tv_tensors.Image),
+            pytest.param(
+                F._color._grayscale_to_rgb_image_cvcuda,
+                None,
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if kernel is F._color._grayscale_to_rgb_image_cvcuda:
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.grayscale_to_rgb, kernel=kernel, input_type=input_type)
 
-    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     def test_transform(self, make_input):
         check_transform(transforms.RGB(), make_input(color_space="GRAY"))
 
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("fn", [F.grayscale_to_rgb, transform_cls_to_functional(transforms.RGB)])
-    def test_image_correctness(self, fn):
-        image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")
+    def test_image_correctness(self, make_input, fn):
+        image = make_input(dtype=torch.uint8, device="cpu", color_space="GRAY")
 
         actual = fn(image)
+
+        if make_input is make_image_cvcuda:
+            image = F.cvcuda_to_tensor(image)[0].cpu()
+
         expected = F.to_image(F.grayscale_to_rgb(F.to_pil_image(image)))
 
         assert_equal(actual, expected, rtol=0, atol=1)
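Note: a minimal sketch of the round-trip these new tests rely on. `make_image_cvcuda` and `F.cvcuda_to_tensor` appear in the tests above; `F.to_cvcuda_tensor` and the NHWC layout are assumptions about the pre-existing CVCUDA conversion utilities, not something this diff defines:

```python
# Hedged sketch, assuming CVCUDA is installed and torchvision exposes
# to_cvcuda_tensor/cvcuda_to_tensor conversion helpers (only the latter
# appears in this diff). Requires a CUDA device.
import torch
import torchvision.transforms.v2.functional as F

image = torch.randint(0, 256, (1, 3, 32, 32), dtype=torch.uint8, device="cuda")
cv_image = F.to_cvcuda_tensor(image)      # assumed torch NCHW -> cvcuda NHWC conversion
gray = F.rgb_to_grayscale(cv_image)       # dispatches to the new CVCUDA kernel
back = F.cvcuda_to_tensor(gray)[0].cpu()  # as in the tests: back to torch for comparison
```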
diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index bf4ae55d232..d71f32ee398 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -5,6 +5,7 @@
 import torch
 from torchvision import transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform
+from torchvision.transforms.v2.functional._utils import _is_cvcuda_tensor
 
 from ._transform import _RandomApplyTransform
 from ._utils import query_chw
@@ -22,6 +23,8 @@ class Grayscale(Transform):
 
     _v1_transform_cls = _transforms.Grayscale
 
+    _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)
+
     def __init__(self, num_output_channels: int = 1):
         super().__init__()
         self.num_output_channels = num_output_channels
@@ -44,6 +47,8 @@ class RandomGrayscale(_RandomApplyTransform):
 
     _v1_transform_cls = _transforms.RandomGrayscale
 
+    _transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,)
+
     def __init__(self, p: float = 0.1) -> None:
         super().__init__(p=p)
 
@@ -62,6 +67,8 @@ class RGB(Transform):
         to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions
     """
 
+    _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)
+
     def __init__(self):
         super().__init__()
 
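Note: the `_transformed_types` additions work because the v2 dispatch machinery accepts type-check callables as well as classes in that tuple. A simplified paraphrase of `check_type` from `torchvision.transforms.v2._utils` (a sketch of the idea, not the verbatim implementation):

```python
# Entries in _transformed_types may be classes (checked with isinstance)
# or predicates such as _is_cvcuda_tensor (checked by calling them).
def check_type(obj, types_or_checks):
    for type_or_check in types_or_checks:
        if isinstance(type_or_check, type):
            if isinstance(obj, type_or_check):
                return True
        elif type_or_check(obj):
            return True
    return False
```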
diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py
index bb6051b4e61..0e7d1199c74 100644
--- a/torchvision/transforms/v2/_utils.py
+++ b/torchvision/transforms/v2/_utils.py
@@ -16,7 +16,7 @@
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
 
 from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
-from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
+from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor
 
 
 def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]:
@@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
     chws = {
         tuple(get_dimensions(inpt))
         for inpt in flat_inputs
-        if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
+        if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor))
     }
     if not chws:
         raise TypeError("No image or video was found in the sample")
diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index be254c0d63a..392bc06da9f 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -1,3 +1,5 @@
+from typing import TYPE_CHECKING
+
 import PIL.Image
 import torch
 from torch.nn.functional import conv2d
@@ -9,7 +11,13 @@
 
 from ._misc import _num_value_bits, to_dtype_image
 from ._type_conversion import pil_to_tensor, to_pil_image
-from ._utils import _get_kernel, _register_kernel_internal
+from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal
+
+
+CVCUDA_AVAILABLE = _is_cvcuda_available()
+
+if TYPE_CHECKING:
+    import cvcuda  # type: ignore[import-not-found]
 
 
 def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
@@ -63,6 +71,35 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int
     return _FP.to_grayscale(image, num_output_channels=num_output_channels)
 
 
+def _rgb_to_grayscale_image_cvcuda(
+    image: "cvcuda.Tensor",
+    num_output_channels: int = 1,
+) -> "cvcuda.Tensor":
+    cvcuda = _import_cvcuda()
+
+    if not (num_output_channels == 1 or num_output_channels == 3):
+        raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
+
+    if image.shape[3] == 1 and num_output_channels == 1:
+        # no work to do if already a single channel
+        return image
+
+    if image.shape[3] == 1 and num_output_channels == 3:
+        # just duplicate the channels
+        return cvcuda.cvtcolor(image, cvcuda.ColorConversion.GRAY2RGB)
+
+    gray = cvcuda.cvtcolor(image, cvcuda.ColorConversion.RGB2GRAY)
+
+    if num_output_channels == 3:
+        gray = cvcuda.cvtcolor(gray, cvcuda.ColorConversion.GRAY2RGB)
+
+    return gray
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(rgb_to_grayscale, _import_cvcuda().Tensor)(_rgb_to_grayscale_image_cvcuda)
+
+
 def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor:
     """See :class:`~torchvision.transforms.v2.RGB` for details."""
     if torch.jit.is_scripting():
@@ -89,6 +126,22 @@ def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
     return image.convert(mode="RGB")
 
 
+def _grayscale_to_rgb_image_cvcuda(
+    image: "cvcuda.Tensor",
+) -> "cvcuda.Tensor":
+    cvcuda = _import_cvcuda()
+
+    if image.shape[3] == 3:
+        # if we already have RGB channels, just return the image
+        return image
+
+    return cvcuda.cvtcolor(image, cvcuda.ColorConversion.GRAY2RGB)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(grayscale_to_rgb, _import_cvcuda().Tensor)(_grayscale_to_rgb_image_cvcuda)
+
+
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
     ratio = float(ratio)
     fp = image1.is_floating_point()
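Note: the `_register_kernel_internal` calls hook the CVCUDA kernels into the functionals' type-based dispatch. Conceptually (a hedged sketch of the mechanism, not the exact torchvision internals), the functional resolves a kernel from the input's runtime type; the kernels then index `image.shape[3]` because CVCUDA tensors are assumed here to be batched NHWC, with channels last:

```python
# Sketch of the dispatch that registration enables: _get_kernel looks up
# the kernel registered for (functional, type(inpt)) and forwards the call.
def rgb_to_grayscale_sketch(inpt, num_output_channels=1):
    kernel = _get_kernel(rgb_to_grayscale, type(inpt))  # registry lookup by input type
    return kernel(inpt, num_output_channels=num_output_channels)
```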
"cvcuda.Tensor", +) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + + if image.shape[3] == 3: + # if we already have RGB channels, just return the image + return image + + return cvcuda.cvtcolor(image, cvcuda.ColorConversion.GRAY2RGB) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(grayscale_to_rgb, _import_cvcuda().Tensor)(_grayscale_to_rgb_image_cvcuda) + + def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: ratio = float(ratio) fp = image1.is_floating_point() diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 6b8f19f12f4..107dc837d29 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,6 +51,14 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) +def _get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: + return [image.shape[3], image.shape[1], image.shape[2]] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_dimensions, _import_cvcuda().Tensor)(_get_dimensions_image_cvcuda) + + def get_num_channels(inpt: torch.Tensor) -> int: if torch.jit.is_scripting(): return get_num_channels_image(inpt)