Skip to content

Commit 0146044

Browse files
committed
gaussian_noise cvcuda backend
1 parent fbea584 commit 0146044

File tree

2 files changed

+68
-4
lines changed

2 files changed

+68
-4
lines changed

test/test_transforms_v2.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4075,14 +4075,28 @@ def test_kernel_uint8(self, make_input):
40754075

40764076
@pytest.mark.parametrize(
    "make_input",
    [
        make_image_tensor,
        make_image,
        make_video,
        pytest.param(
            make_image_cvcuda,
            marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available"),
        ),
    ],
)
def test_functional_float(self, make_input):
    # Run the shared functional checks against a float32 sample of every
    # parametrized input kind; the CV-CUDA case is skipped when the optional
    # dependency is not available.
    float_input = make_input(dtype=torch.float32)
    check_functional(F.gaussian_noise, float_input)
40824089

40834090
@pytest.mark.parametrize(
    "make_input",
    [
        make_image_tensor,
        make_image,
        make_video,
        pytest.param(
            make_image_cvcuda,
            marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available"),
        ),
    ],
)
def test_functional_uint8(self, make_input):
    # Run the shared functional checks against a uint8 sample of every
    # parametrized input kind; the CV-CUDA case is skipped when the optional
    # dependency is not available.
    uint8_input = make_input(dtype=torch.uint8)
    check_functional(F.gaussian_noise, uint8_input)
@@ -4093,14 +4107,28 @@ def test_functional_uint8(self, make_input):
40934107
(F.gaussian_noise, torch.Tensor),
40944108
(F.gaussian_noise_image, tv_tensors.Image),
40954109
(F.gaussian_noise_video, tv_tensors.Video),
4110+
pytest.param(
4111+
F._misc._gaussian_noise_cvcuda,
4112+
"cvcuda.Tensor",
4113+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available"),
4114+
),
40964115
],
40974116
)
40984117
def test_functional_signature(self, kernel, input_type):
4118+
if input_type == "cvcuda.Tensor":
4119+
input_type = _import_cvcuda().Tensor
40994120
check_functional_kernel_signature_match(F.gaussian_noise, kernel=kernel, input_type=input_type)
41004121

41014122
@pytest.mark.parametrize(
41024123
"make_input",
4103-
[make_image_tensor, make_image, make_video],
4124+
[
4125+
make_image_tensor,
4126+
make_image,
4127+
make_video,
4128+
pytest.param(
4129+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available")
4130+
),
4131+
],
41044132
)
41054133
def test_transform_float(self, make_input):
41064134
def adapter(_, input, __):
@@ -4118,7 +4146,14 @@ def adapter(_, input, __):
41184146

41194147
@pytest.mark.parametrize(
41204148
"make_input",
4121-
[make_image_tensor, make_image, make_video],
4149+
[
4150+
make_image_tensor,
4151+
make_image,
4152+
make_video,
4153+
pytest.param(
4154+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available")
4155+
),
4156+
],
41224157
)
41234158
def test_transform_uint8(self, make_input):
41244159
def adapter(_, input, __):

torchvision/transforms/v2/functional/_misc.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,35 @@ def _gaussian_noise_pil(
238238
raise ValueError("Gaussian Noise is not implemented for PIL images.")
239239

240240

241+
def _gaussian_noise_cvcuda(
242+
image: "cvcuda.Tensor",
243+
mean: float = 0.0,
244+
sigma: float = 0.1,
245+
clip: bool = True,
246+
) -> "cvcuda.Tensor":
247+
cvcuda = _import_cvcuda()
248+
249+
batch_size = image.shape[0]
250+
mu_tensor = cvcuda.as_tensor(torch.full((batch_size,), mean, dtype=torch.float32).cuda(), "N")
251+
sigma_tensor = cvcuda.as_tensor(torch.full((batch_size,), sigma, dtype=torch.float32).cuda(), "N")
252+
253+
# per-channel means each channel gets unique random noise, same behavior as torch.randn_like
254+
# produce a seed with torch RNG, if seed is manually set then this will be deterministic
255+
# note: clip is not supported in CV-CUDA, so we don't need to clamp the values
256+
# by default, clamping is done for floats, and uint8 overflows so is clamped from 0-255 anyways
257+
return cvcuda.gaussiannoise(
258+
image,
259+
mu=mu_tensor,
260+
sigma=sigma_tensor,
261+
per_channel=True,
262+
seed=int(torch.empty((), dtype=torch.int64).random_().item()),
263+
)
264+
265+
266+
# Register the CV-CUDA kernel for ``gaussian_noise`` only when CV-CUDA is
# available, because the tensor type has to be imported to register against it.
if CVCUDA_AVAILABLE:
    _register_kernel_internal(
        gaussian_noise,
        _import_cvcuda().Tensor,
    )(_gaussian_noise_cvcuda)
268+
269+
241270
def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
242271
"""See :func:`~torchvision.transforms.v2.ToDtype` for details."""
243272
if torch.jit.is_scripting():

0 commit comments

Comments
 (0)