Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
Expand All @@ -41,7 +42,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
Expand Down Expand Up @@ -5824,7 +5824,18 @@ def test_kernel_image(self, dtype, device):
def test_kernel_video(self):
    """Run the shared kernel checks against ``F.invert_video`` on a sample video."""
    video = make_video()
    check_kernel(F.invert_video, video)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_image_pil,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_functional(self, make_input):
    """Run the shared functional checks for ``F.invert`` on one parametrized input kind.

    Args:
        make_input: Factory that builds the sample input; supplied by the parametrization.
    """
    sample = make_input()
    check_functional(F.invert, sample)

Expand All @@ -5835,20 +5846,51 @@ def test_functional(self, make_input):
(F._color._invert_image_pil, PIL.Image.Image),
(F.invert_image, tv_tensors.Image),
(F.invert_video, tv_tensors.Video),
pytest.param(
F._color._invert_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
),
],
)
def test_functional_signature(self, kernel, input_type):
    """Check that ``kernel`` has a signature matching the ``F.invert`` functional.

    Args:
        kernel: The kernel under test.
        input_type: The input type the kernel is registered for. The CV-CUDA entry is
            parametrized with ``None`` and its type is resolved here instead.
    """
    # The CV-CUDA tensor type is looked up lazily, only for the CV-CUDA kernel, so the
    # parametrization itself never has to import cvcuda.
    is_cvcuda_kernel = kernel is F._color._invert_image_cvcuda
    resolved_type = _import_cvcuda().Tensor if is_cvcuda_kernel else input_type
    check_functional_kernel_signature_match(F.invert, kernel=kernel, input_type=resolved_type)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image_pil,
make_image,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_transform(self, make_input):
    """Run the shared transform checks for ``RandomInvert`` with ``p=1``.

    Args:
        make_input: Factory that builds the sample input; supplied by the parametrization.
    """
    # p=1 is passed so the transform is always applied to the sample.
    always_invert = transforms.RandomInvert(p=1)
    check_transform(always_invert, make_input())

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
@pytest.mark.parametrize("fn", [F.invert, transform_cls_to_functional(transforms.RandomInvert, p=1)])
def test_correctness_image(self, fn):
image = make_image(dtype=torch.uint8, device="cpu")
def test_correctness_image(self, make_input, fn):
    """Compare the output of ``fn`` against a reference computed through the PIL path.

    Args:
        make_input: Factory for the uint8 input image (``make_image`` or ``make_image_cvcuda``).
        fn: The callable under test (``F.invert`` or the functional form of ``RandomInvert``).
    """
    image = make_input(dtype=torch.uint8, device="cpu")
    actual = fn(image)

    # The reference goes through ``F.to_pil_image``, so a CV-CUDA input is first converted
    # back to a CPU torch tensor. Index 0 presumably selects the single batch entry -- confirm.
    if make_input is make_image_cvcuda:
        reference_input = F.cvcuda_to_tensor(image)[0].cpu()
    else:
        reference_input = image

    pil_inverted = F.invert(F.to_pil_image(reference_input))
    expected = F.to_image(pil_inverted)

    assert_equal(actual, expected)
Expand Down
3 changes: 3 additions & 0 deletions torchvision/transforms/v2/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from torchvision import transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2.functional._utils import _is_cvcuda_tensor

from ._transform import _RandomApplyTransform
from ._utils import query_chw
Expand Down Expand Up @@ -282,6 +283,8 @@ class RandomInvert(_RandomApplyTransform):

_v1_transform_cls = _transforms.RandomInvert

# Extend the inherited ``_transformed_types`` with the ``_is_cvcuda_tensor`` check.
# NOTE(review): presumably this is what lets CV-CUDA tensors be picked up by this
# transform -- confirm against the base Transform's input filtering.
_transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,)

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
    """Invert ``inpt`` by dispatching ``F.invert`` through ``self._call_kernel``.

    Args:
        inpt: The input to invert.
        params: Accepted as part of the ``transform`` signature; not used here.

    Returns:
        Whatever ``self._call_kernel(F.invert, inpt)`` returns.
    """
    inverted = self._call_kernel(F.invert, inpt)
    return inverted

Expand Down
47 changes: 46 additions & 1 deletion torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TYPE_CHECKING

import PIL.Image
import torch
from torch.nn.functional import conv2d
Expand All @@ -9,7 +11,13 @@

from ._misc import _num_value_bits, to_dtype_image
from ._type_conversion import pil_to_tensor, to_pil_image
from ._utils import _get_kernel, _register_kernel_internal
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal


# Evaluated once at import time; used further down in this module to decide whether the
# CV-CUDA kernel gets registered.
CVCUDA_AVAILABLE = _is_cvcuda_available()

if TYPE_CHECKING:
    # Imported for static type checkers only, so the quoted "cvcuda.Tensor" annotations
    # resolve without importing cvcuda at runtime.
    import cvcuda  # type: ignore[import-not-found]


def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
Expand Down Expand Up @@ -680,6 +688,43 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
return invert_image(video)


# Module-level cache for the constant "base" and "scale" CV-CUDA tensors used by
# ``_invert_image_cvcuda``. It starts empty and is filled on the first call to that kernel.
_invert_cvcuda_tensors: dict[str, "cvcuda.Tensor"] = {}


def _invert_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
    """Invert a CV-CUDA image tensor using ``cvcuda.normalize``.

    The normalize operator computes
    ``output = (input - base) * scale * global_scale + shift``; with ``base = 0``,
    ``scale = -1`` and ``global_scale = 1`` this reduces to ``shift - input``.

    Args:
        image: The CV-CUDA tensor to invert. Its dtype must be uint8 or float32.

    Returns:
        The result of ``cvcuda.normalize`` applied to ``image``.

    Raises:
        ValueError: If ``image.dtype`` is neither ``cvcuda.Type.U8`` nor ``cvcuda.Type.F32``.
    """
    cvcuda = _import_cvcuda()

    # The constant tensors are created lazily, so nothing is allocated unless CV-CUDA is
    # actually used. They are static and small, hence cached in the module-level dictionary.
    for key, fill in (("base", 0.0), ("scale", -1.0)):
        if key not in _invert_cvcuda_tensors:
            constant = torch.tensor([fill, fill, fill], dtype=torch.float32, device="cuda")
            _invert_cvcuda_tensors[key] = cvcuda.as_tensor(constant.reshape(1, 1, 1, 3).contiguous(), "NHWC")

    # The shift is the largest value of the dtype's range: 255 for uint8, 1.0 for float32.
    if image.dtype == cvcuda.Type.U8:
        shift = 255.0
    elif image.dtype == cvcuda.Type.F32:
        shift = 1.0
    else:
        raise ValueError(f"Input image dtype must be uint8 or float32, got {image.dtype}")

    return cvcuda.normalize(
        image,
        base=_invert_cvcuda_tensors["base"],
        scale=_invert_cvcuda_tensors["scale"],
        globalscale=1.0,
        globalshift=shift,
    )


# Register the CV-CUDA kernel for ``invert`` only when CV-CUDA is available, so the
# ``cvcuda.Tensor`` type is looked up only in that case.
if CVCUDA_AVAILABLE:
    _register_kernel_internal(invert, _import_cvcuda().Tensor)(_invert_image_cvcuda)


def permute_channels(inpt: torch.Tensor, permutation: list[int]) -> torch.Tensor:
"""Permute the channels of the input according to the given permutation.

Expand Down