diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 3ce603c3ed2..cfac685ffcd 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -21,6 +21,7 @@ import torchvision.transforms.v2 as transforms from common_utils import ( + assert_close, assert_equal, cache, cpu_and_cuda, @@ -41,7 +42,6 @@ ) from torch import nn -from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors @@ -2822,7 +2822,18 @@ class TestAdjustBrightness: def test_kernel(self, kernel, make_input, dtype, device): check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image_pil, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) def test_functional(self, make_input): check_functional(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) @@ -2833,19 +2844,42 @@ def test_functional(self, make_input): (F._color._adjust_brightness_image_pil, PIL.Image.Image), (F.adjust_brightness_image, tv_tensors.Image), (F.adjust_brightness_video, tv_tensors.Video), + pytest.param( + F._color._adjust_brightness_image_cvcuda, + None, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if kernel is F._color._adjust_brightness_image_cvcuda: + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, 
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) @pytest.mark.parametrize("brightness_factor", _CORRECTNESS_BRIGHTNESS_FACTORS) - def test_image_correctness(self, brightness_factor): - image = make_image(dtype=torch.uint8, device="cpu") + def test_image_correctness(self, make_input, brightness_factor): + image = make_input(dtype=torch.uint8, device="cpu") actual = F.adjust_brightness(image, brightness_factor=brightness_factor) + + if make_input is make_image_cvcuda: + image = F.cvcuda_to_tensor(image)[0].cpu() + expected = F.to_image(F.adjust_brightness(F.to_pil_image(image), brightness_factor=brightness_factor)) - torch.testing.assert_close(actual, expected) + if make_input is make_image_cvcuda: + assert_close(actual, expected, rtol=0, atol=1) + else: + assert_close(actual, expected) class TestCutMixMixUp: @@ -6053,7 +6087,18 @@ def test_kernel_image(self, dtype, device): def test_kernel_video(self): check_kernel(F.adjust_contrast_video, make_video(), contrast_factor=0.5) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video]) + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image, + make_image_pil, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) def test_functional(self, make_input): check_functional(F.adjust_contrast, make_input(), contrast_factor=0.5) @@ -6064,9 +6109,16 @@ def test_functional(self, make_input): (F._color._adjust_contrast_image_pil, PIL.Image.Image), (F.adjust_contrast_image, tv_tensors.Image), (F.adjust_contrast_video, tv_tensors.Video), + pytest.param( + F._color._adjust_contrast_image_cvcuda, + None, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if kernel is F._color._adjust_contrast_image_cvcuda: + input_type = 
_import_cvcuda().Tensor check_functional_kernel_signature_match(F.adjust_contrast, kernel=kernel, input_type=input_type) def test_functional_error(self): @@ -6076,11 +6128,24 @@ def test_functional_error(self): with pytest.raises(ValueError, match="is not non-negative"): F.adjust_contrast(make_image(), contrast_factor=-1) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) @pytest.mark.parametrize("contrast_factor", [0.1, 0.5, 1.0]) - def test_correctness_image(self, contrast_factor): - image = make_image(dtype=torch.uint8, device="cpu") + def test_correctness_image(self, make_input, contrast_factor): + image = make_input(dtype=torch.uint8, device="cpu") actual = F.adjust_contrast(image, contrast_factor=contrast_factor) + + if make_input is make_image_cvcuda: + image = F.cvcuda_to_tensor(image)[0].cpu() + expected = F.to_image(F.adjust_contrast(F.to_pil_image(image), contrast_factor=contrast_factor)) assert_close(actual, expected, rtol=0, atol=1) @@ -6135,7 +6200,18 @@ def test_kernel_image(self, dtype, device): def test_kernel_video(self): check_kernel(F.adjust_hue_video, make_video(), hue_factor=0.25) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video]) + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image, + make_image_pil, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) def test_functional(self, make_input): check_functional(F.adjust_hue, make_input(), hue_factor=0.25) @@ -6146,9 +6222,16 @@ def test_functional(self, make_input): (F._color._adjust_hue_image_pil, PIL.Image.Image), (F.adjust_hue_image, tv_tensors.Image), (F.adjust_hue_video, tv_tensors.Video), + pytest.param( + F._color._adjust_hue_image_cvcuda, + None, + marks=pytest.mark.skipif(not 
CVCUDA_AVAILABLE, reason="CVCUDA not available"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if kernel is F._color._adjust_hue_image_cvcuda: + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.adjust_hue, kernel=kernel, input_type=input_type) def test_functional_error(self): @@ -6159,11 +6242,25 @@ def test_functional_error(self): with pytest.raises(ValueError, match=re.escape("is not in [-0.5, 0.5]")): F.adjust_hue(make_image(), hue_factor=hue_factor) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) @pytest.mark.parametrize("hue_factor", [-0.5, -0.3, 0.0, 0.2, 0.5]) - def test_correctness_image(self, hue_factor): - image = make_image(dtype=torch.uint8, device="cpu") + def test_correctness_image(self, make_input, hue_factor): + image = make_input(dtype=torch.uint8, device="cpu") actual = F.adjust_hue(image, hue_factor=hue_factor) + + if make_input is make_image_cvcuda: + actual = F.cvcuda_to_tensor(actual)[0].cpu() + image = F.cvcuda_to_tensor(image)[0].cpu() + expected = F.to_image(F.adjust_hue(F.to_pil_image(image), hue_factor=hue_factor)) mae = (actual.float() - expected.float()).abs().mean() @@ -6179,7 +6276,18 @@ def test_kernel_image(self, dtype, device): def test_kernel_video(self): check_kernel(F.adjust_saturation_video, make_video(), saturation_factor=0.5) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video]) + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image, + make_image_pil, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) def test_functional(self, make_input): check_functional(F.adjust_saturation, make_input(), saturation_factor=0.5) @@ -6190,9 +6298,16 @@ def test_functional(self, 
make_input): (F._color._adjust_saturation_image_pil, PIL.Image.Image), (F.adjust_saturation_image, tv_tensors.Image), (F.adjust_saturation_video, tv_tensors.Video), + pytest.param( + F._color._adjust_saturation_image_cvcuda, + None, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if kernel is F._color._adjust_saturation_image_cvcuda: + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.adjust_saturation, kernel=kernel, input_type=input_type) def test_functional_error(self): @@ -6202,11 +6317,25 @@ def test_functional_error(self): with pytest.raises(ValueError, match="is not non-negative"): F.adjust_saturation(make_image(), saturation_factor=-1) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) + @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) @pytest.mark.parametrize("saturation_factor", [0.1, 0.5, 1.0]) - def test_correctness_image(self, saturation_factor): - image = make_image(dtype=torch.uint8, device="cpu") + def test_correctness_image(self, make_input, color_space, saturation_factor): + image = make_input(dtype=torch.uint8, color_space=color_space, device="cpu") actual = F.adjust_saturation(image, saturation_factor=saturation_factor) + + if make_input is make_image_cvcuda: + image = F.cvcuda_to_tensor(image)[0].cpu() + expected = F.to_image(F.adjust_saturation(F.to_pil_image(image), saturation_factor=saturation_factor)) assert_close(actual, expected, rtol=0, atol=1) @@ -6339,7 +6468,15 @@ def test_correctness_image_ten_crop(self, fn_or_class, vertical_flip): class TestColorJitter: @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image_pil, make_image, make_video], + [ + make_image_tensor, + make_image_pil, + make_image, + make_video, + pytest.param( + 
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], ) @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -6383,12 +6520,21 @@ def test_transform_error(self): with pytest.raises(ValueError, match="values should be between"): transforms.ColorJitter(hue=1) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) @pytest.mark.parametrize("brightness", [None, 0.1, (0.2, 0.3)]) @pytest.mark.parametrize("contrast", [None, 0.4, (0.5, 0.6)]) @pytest.mark.parametrize("saturation", [None, 0.7, (0.8, 0.9)]) @pytest.mark.parametrize("hue", [None, 0.3, (-0.1, 0.2)]) - def test_transform_correctness(self, brightness, contrast, saturation, hue): - image = make_image(dtype=torch.uint8, device="cpu") + def test_transform_correctness(self, make_input, brightness, contrast, saturation, hue): + image = make_input(dtype=torch.uint8, device="cpu") transform = transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) @@ -6396,11 +6542,18 @@ def test_transform_correctness(self, brightness, contrast, saturation, hue): torch.manual_seed(0) actual = transform(image) + if make_input is make_image_cvcuda: + actual = F.cvcuda_to_tensor(actual)[0].cpu() + image = F.cvcuda_to_tensor(image)[0].cpu() + torch.manual_seed(0) expected = F.to_image(transform(F.to_pil_image(image))) mae = (actual.float() - expected.float()).abs().mean() - assert mae < 2 + mae_threshold = 2 + if make_input is make_image_cvcuda: + mae_threshold = 3 + assert mae < mae_threshold, f"MAE: {mae}" class TestRgbToGrayscale: diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index bf4ae55d232..083ec52dc8b 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -5,6 
+5,7 @@ import torch from torchvision import transforms as _transforms from torchvision.transforms.v2 import functional as F, Transform +from torchvision.transforms.v2.functional._utils import _is_cvcuda_tensor from ._transform import _RandomApplyTransform from ._utils import query_chw @@ -96,6 +97,8 @@ class ColorJitter(Transform): _v1_transform_cls = _transforms.ColorJitter + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) + def _extract_params_for_v1_transform(self) -> dict[str, Any]: return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()} diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index be254c0d63a..9fd3ddbeae5 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import PIL.Image import torch from torch.nn.functional import conv2d @@ -9,7 +11,15 @@ from ._misc import _num_value_bits, to_dtype_image from ._type_conversion import pil_to_tensor, to_pil_image -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: @@ -135,6 +145,19 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to return adjust_brightness_image(video, brightness_factor=brightness_factor) +def _adjust_brightness_image_cvcuda(image: "cvcuda.Tensor", brightness_factor: float) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + + cv_brightness = torch.tensor([brightness_factor], dtype=torch.float32, device="cuda") + cv_brightness = cvcuda.as_tensor(cv_brightness, "N") + 
+ return cvcuda.brightness_contrast(image, brightness=cv_brightness) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(adjust_brightness, _import_cvcuda().Tensor)(_adjust_brightness_image_cvcuda) + + def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor: """Adjust saturation.""" if torch.jit.is_scripting(): @@ -174,6 +197,37 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to return adjust_saturation_image(video, saturation_factor=saturation_factor) +def _adjust_saturation_image_cvcuda(image: "cvcuda.Tensor", saturation_factor: float) -> "cvcuda.Tensor": + if saturation_factor < 0: + raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") + + c = image.shape[3] + if c not in [1, 3, 4]: + raise TypeError(f"Input image tensor permitted channel values are 1, 3, or 4, but found {c}") + + if c == 1: # Match PIL behaviour: single-channel input is returned unchanged + return image + + # grayscale weights for R, G, B, blended with sf into the rows of the 3x4 twist matrix + sf = saturation_factor + r, g, b = 0.2989, 0.587, 0.114 + twist_data = [ + [sf + (1 - sf) * r, (1 - sf) * g, (1 - sf) * b, 0.0], + [(1 - sf) * r, sf + (1 - sf) * g, (1 - sf) * b, 0.0], + [(1 - sf) * r, (1 - sf) * g, sf + (1 - sf) * b, 0.0], + ] + twist_tensor = cvcuda.as_tensor( + torch.tensor(twist_data, dtype=torch.float32, device="cuda"), + "HW", + ) + + return cvcuda.color_twist(image, twist_tensor) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(adjust_saturation, cvcuda.Tensor)(_adjust_saturation_image_cvcuda) + + def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor: """See :class:`~torchvision.transforms.RandomAutocontrast`""" if torch.jit.is_scripting(): @@ -213,6 +267,48 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch. 
return adjust_contrast_image(video, contrast_factor=contrast_factor) +def _adjust_contrast_image_cvcuda(image: "cvcuda.Tensor", contrast_factor: float) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + + if contrast_factor < 0: + raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") + + c = image.shape[3] + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") + + # float32 is the only float dtype handled (float16: TODO). NOTE(review): convertto scales look swapped; confirm + fp = image.dtype == cvcuda.Type.F32 + + if c == 3 and not fp: + grayscale = cvcuda.cvtcolor(image, cvcuda.ColorConversion.RGB2GRAY) + elif c == 3 and fp: + grayscale = cvcuda.cvtcolor( + cvcuda.convertto(image, cvcuda.Type.U8, scale=1.0 / 255.0, offset=0.0), cvcuda.ColorConversion.RGB2GRAY + ) + else: + grayscale = image + + contrast = cvcuda.as_tensor(torch.tensor([contrast_factor], dtype=torch.float32, device="cuda")) + + # torchvision uses the mean of the grayscale image as the center of contrast + # we will compute that here using torch as well for consistency + torch_image = torch.as_tensor(grayscale.cuda()) + mean = torch.mean(torch_image.float()) + contrast_center = cvcuda.as_tensor(torch.tensor([mean.item()], dtype=torch.float32, device="cuda")) + + result = cvcuda.brightness_contrast(image, contrast=contrast, contrast_center=contrast_center) + + if fp: + result = cvcuda.convertto(result, cvcuda.Type.F32, scale=256.0 - 1e-3, offset=0.0) + + return result + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(adjust_contrast, _import_cvcuda().Tensor)(_adjust_contrast_image_cvcuda) + + def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor: """See :class:`~torchvision.transforms.RandomAdjustSharpness`""" if torch.jit.is_scripting(): @@ -404,6 +500,44 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: return adjust_hue_image(video, hue_factor=hue_factor) +def _adjust_hue_image_cvcuda(image: 
"cvcuda.Tensor", hue_factor: float) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") + + c = image.shape[3] + if c not in [1, 3, 4]: + raise TypeError(f"Input image tensor permitted channel values are 1, 3, or 4, but found {c}") + + if c == 1: # Match PIL behaviour: single-channel input is returned unchanged + return image + + # float32 is the only float dtype handled (float16: TODO). NOTE(review): convertto scales look swapped; confirm + fp = image.dtype == cvcuda.Type.F32 + if fp: + image = cvcuda.convertto(image, cvcuda.Type.U8, scale=1.0 / 255.0, offset=0.0) + + # no native adjust_hue, use CV-CUDA for color conversion, use torch for elementwise operations + # CV-CUDA accelerates the HSV conversion + hsv = cvcuda.cvtcolor(image, cvcuda.ColorConversion.RGB2HSV) + # then use torch for elementwise operations + hsv_torch = torch.as_tensor(hsv.cuda()).float() + hsv_torch[..., 0] = (hsv_torch[..., 0] + hue_factor * 180) % 180 + # convert back to cvcuda tensor and accelerate the HSV2RGB conversion + hsv_modified = cvcuda.as_tensor(hsv_torch.to(torch.uint8), "NHWC") + rgb = cvcuda.cvtcolor(hsv_modified, cvcuda.ColorConversion.HSV2RGB) + + if fp: + rgb = cvcuda.convertto(rgb, cvcuda.Type.F32, scale=256.0 - 1e-3, offset=0.0) + + return rgb + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(adjust_hue, _import_cvcuda().Tensor)(_adjust_hue_image_cvcuda) + + def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: """Adjust gamma.""" if torch.jit.is_scripting():