diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 3ce603c3ed2..5d87646eba3 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -21,6 +21,7 @@ import torchvision.transforms.v2 as transforms from common_utils import ( + assert_close, assert_equal, cache, cpu_and_cuda, @@ -41,7 +42,6 @@ ) from torch import nn -from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors @@ -4074,14 +4074,28 @@ def test_kernel_uint8(self, make_input): @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image, make_video], + [ + make_image_tensor, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available") + ), + ], ) def test_functional_float(self, make_input): check_functional(F.gaussian_noise, make_input(dtype=torch.float32)) @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image, make_video], + [ + make_image_tensor, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available") + ), + ], ) def test_functional_uint8(self, make_input): check_functional(F.gaussian_noise, make_input(dtype=torch.uint8)) @@ -4092,14 +4106,28 @@ def test_functional_uint8(self, make_input): (F.gaussian_noise, torch.Tensor), (F.gaussian_noise_image, tv_tensors.Image), (F.gaussian_noise_video, tv_tensors.Video), + pytest.param( + F._misc._gaussian_noise_image_cvcuda, + None, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if kernel is F._misc._gaussian_noise_image_cvcuda: + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.gaussian_noise, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image, make_video], + [ + make_image_tensor, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available") + ), + ], ) def test_transform_float(self, make_input): def adapter(_, input, __): @@ -4117,7 +4145,14 @@ def adapter(_, input, __): @pytest.mark.parametrize( "make_input", - [make_image_tensor, make_image, make_video], + [ + make_image_tensor, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available") + ), + ], ) def test_transform_uint8(self, make_input): def adapter(_, input, __): diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 305149c87b1..df0a0561962 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -9,6 +9,7 @@ from torchvision import transforms as _transforms, tv_tensors from torchvision.transforms.v2 import functional as F, Transform +from torchvision.transforms.v2.functional._utils import _is_cvcuda_tensor from ._utils import ( _parse_labels_getter, @@ -240,6 +241,8 @@ class GaussianNoise(Transform): Default is True. """ + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) + def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip=True) -> None: super().__init__() self.mean = mean diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index daf263df046..a3863dadc0b 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import Optional, TYPE_CHECKING import PIL.Image import torch @@ -13,7 +13,12 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] def normalize( @@ -231,6 +236,35 @@ def _gaussian_noise_pil( raise ValueError("Gaussian Noise is not implemented for PIL images.") +def _gaussian_noise_image_cvcuda( + image: "cvcuda.Tensor", + mean: float = 0.0, + sigma: float = 0.1, + clip: bool = True, +) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + + batch_size = image.shape[0] + mu_tensor = cvcuda.as_tensor(torch.full((batch_size,), mean, dtype=torch.float32).cuda(), "N") + sigma_tensor = cvcuda.as_tensor(torch.full((batch_size,), sigma, dtype=torch.float32).cuda(), "N") + + # per-channel means each channel gets unique random noise, same behavior as torch.randn_like + # produce a seed with torch RNG, if seed is manually set then this will be deterministic + # note: clip is not supported in CV-CUDA, so we don't need to clamp the values + # by default, clamping is done for floats, and uint8 overflows so is clamped from 0-255 anyways + return cvcuda.gaussiannoise( + image, + mu=mu_tensor, + sigma=sigma_tensor, + per_channel=True, + seed=int(torch.empty((), dtype=torch.int64).random_().item()), + ) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(gaussian_noise, _import_cvcuda().Tensor)(_gaussian_noise_image_cvcuda) + + def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: """See :func:`~torchvision.transforms.v2.ToDtype` for details.""" if torch.jit.is_scripting():