Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 94 additions & 9 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
Expand All @@ -41,7 +42,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
Expand Down Expand Up @@ -6409,7 +6409,17 @@ class TestRgbToGrayscale:
def test_kernel_image(self, dtype, device):
    """Pass ``F.rgb_to_grayscale_image`` and an image of the given dtype/device to ``check_kernel``."""
    check_kernel(F.rgb_to_grayscale_image, make_image(dtype=dtype, device=device))

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_functional(self, make_input):
    """Pass ``F.rgb_to_grayscale`` and a default input built by ``make_input`` to ``check_functional``."""
    check_functional(F.rgb_to_grayscale, make_input())

Expand All @@ -6419,23 +6429,58 @@ def test_functional(self, make_input):
(F.rgb_to_grayscale_image, torch.Tensor),
(F._color._rgb_to_grayscale_image_pil, PIL.Image.Image),
(F.rgb_to_grayscale_image, tv_tensors.Image),
pytest.param(
F._color._rgb_to_grayscale_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
),
],
)
def test_functional_signature(self, kernel, input_type):
    """Check that ``kernel``'s signature matches ``F.rgb_to_grayscale`` for ``input_type``."""
    # The CV-CUDA case is parametrized with input_type=None and the real tensor class is
    # looked up here; presumably so that collecting the test does not require cvcuda
    # to be importable -- TODO confirm.
    if kernel is F._color._rgb_to_grayscale_image_cvcuda:
        input_type = _import_cvcuda().Tensor
    check_functional_kernel_signature_match(F.rgb_to_grayscale, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize("transform", [transforms.Grayscale(), transforms.RandomGrayscale(p=1)])
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_transform(self, transform, make_input):
    """Pass ``transform`` and a default input built by ``make_input`` to ``check_transform``.

    The RandomGrayscale / CV-CUDA combination is skipped.
    """
    # Skipped for the reason given in the skip message below.
    if make_input is make_image_cvcuda and isinstance(transform, transforms.RandomGrayscale):
        pytest.skip("CV-CUDA does not support RandomGrayscale, will have num_output_channels == 3")
    check_transform(transform, make_input())

@pytest.mark.parametrize("num_output_channels", [1, 3])
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
@pytest.mark.parametrize("fn", [F.rgb_to_grayscale, transform_cls_to_functional(transforms.Grayscale)])
def test_image_correctness(self, num_output_channels, color_space, fn):
image = make_image(dtype=torch.uint8, device="cpu", color_space=color_space)
def test_image_correctness(self, num_output_channels, color_space, make_input, fn):
    """Compare ``fn`` against the PIL grayscale conversion of the same uint8 image.

    The result of ``fn`` must equal the PIL-based reference within an absolute
    tolerance of 1 (``rtol=0, atol=1``).
    """
    # Skipped for the reason given in the skip message below.
    if make_input is make_image_cvcuda and num_output_channels == 3:
        pytest.skip("CV-CUDA does not support num_output_channels == 3")

    image = make_input(dtype=torch.uint8, device="cpu", color_space=color_space)

    actual = fn(image, num_output_channels=num_output_channels)

    if make_input is make_image_cvcuda:
        # Convert the CV-CUDA input to a CPU torch tensor so the PIL reference path below
        # can consume it; the [0] presumably drops the batch dimension -- TODO confirm.
        image = F.cvcuda_to_tensor(image)[0].cpu()

    # Reference: convert to PIL, run the same functional on the PIL image, convert back.
    expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels))

    # NOTE(review): `actual` is not converted in the CV-CUDA case; presumably assert_equal
    # accepts CV-CUDA tensors -- confirm.
    assert_equal(actual, expected, rtol=0, atol=1)
Expand Down Expand Up @@ -6473,7 +6518,17 @@ class TestGrayscaleToRgb:
def test_kernel_image(self, dtype, device):
    """Pass ``F.grayscale_to_rgb_image`` and an image of the given dtype/device to ``check_kernel``."""
    check_kernel(F.grayscale_to_rgb_image, make_image(dtype=dtype, device=device))

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_functional(self, make_input):
    """Pass ``F.grayscale_to_rgb`` and a default input built by ``make_input`` to ``check_functional``."""
    check_functional(F.grayscale_to_rgb, make_input())

Expand All @@ -6483,20 +6538,50 @@ def test_functional(self, make_input):
(F.rgb_to_grayscale_image, torch.Tensor),
(F._color._rgb_to_grayscale_image_pil, PIL.Image.Image),
(F.rgb_to_grayscale_image, tv_tensors.Image),
pytest.param(
F._color._rgb_to_grayscale_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
),
],
)
def test_functional_signature(self, kernel, input_type):
    """Check that ``kernel``'s signature matches ``F.grayscale_to_rgb`` for ``input_type``."""
    # NOTE(review): this test targets F.grayscale_to_rgb, yet the visible parametrization
    # and the identity check below use the rgb_to_grayscale kernels (including
    # F._color._rgb_to_grayscale_image_cvcuda). Looks like a copy-paste from
    # TestRgbToGrayscale -- confirm whether the grayscale_to_rgb kernels were intended.
    if kernel is F._color._rgb_to_grayscale_image_cvcuda:
        # The CV-CUDA case is parametrized with input_type=None; the real class is resolved here.
        input_type = _import_cvcuda().Tensor
    check_functional_kernel_signature_match(F.grayscale_to_rgb, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_transform(self, make_input):
    """Pass ``transforms.RGB()`` and a ``color_space="GRAY"`` input from ``make_input`` to ``check_transform``."""
    check_transform(transforms.RGB(), make_input(color_space="GRAY"))

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
@pytest.mark.parametrize("fn", [F.grayscale_to_rgb, transform_cls_to_functional(transforms.RGB)])
def test_image_correctness(self, fn):
image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")
def test_image_correctness(self, make_input, fn):
    """Compare ``fn`` against the PIL grayscale-to-RGB conversion of the same uint8 GRAY image.

    The result of ``fn`` must equal the PIL-based reference within an absolute
    tolerance of 1 (``rtol=0, atol=1``).
    """
    image = make_input(dtype=torch.uint8, device="cpu", color_space="GRAY")

    actual = fn(image)

    if make_input is make_image_cvcuda:
        # Convert the CV-CUDA input to a CPU torch tensor so the PIL reference path below
        # can consume it; the [0] presumably drops the batch dimension -- TODO confirm.
        image = F.cvcuda_to_tensor(image)[0].cpu()

    # Reference: convert to PIL, run the same functional on the PIL image, convert back.
    expected = F.to_image(F.grayscale_to_rgb(F.to_pil_image(image)))

    # NOTE(review): `actual` is not converted in the CV-CUDA case; presumably assert_equal
    # accepts CV-CUDA tensors -- confirm.
    assert_equal(actual, expected, rtol=0, atol=1)
Expand Down
13 changes: 13 additions & 0 deletions torchvision/transforms/v2/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
import torch
from torchvision import transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor

from ._transform import _RandomApplyTransform
from ._utils import query_chw


CVCUDA_AVAILABLE = _is_cvcuda_available()


class Grayscale(Transform):
"""Convert images or videos to grayscale.

Expand All @@ -22,6 +26,9 @@ class Grayscale(Transform):

_v1_transform_cls = _transforms.Grayscale

if CVCUDA_AVAILABLE:
_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def __init__(self, num_output_channels: int = 1):
super().__init__()
self.num_output_channels = num_output_channels
Expand All @@ -44,6 +51,9 @@ class RandomGrayscale(_RandomApplyTransform):

_v1_transform_cls = _transforms.RandomGrayscale

if CVCUDA_AVAILABLE:
_transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,)

def __init__(self, p: float = 0.1) -> None:
super().__init__(p=p)

Expand All @@ -62,6 +72,9 @@ class RGB(Transform):
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions
"""

if CVCUDA_AVAILABLE:
_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def __init__(self):
super().__init__()

Expand Down
5 changes: 3 additions & 2 deletions torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor


def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]:
Expand Down Expand Up @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor))
}
if not chws:
raise TypeError("No image or video was found in the sample")
Expand All @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
tv_tensors.Mask,
tv_tensors.BoundingBoxes,
tv_tensors.KeyPoints,
_is_cvcuda_tensor,
),
)
}
Expand Down
69 changes: 68 additions & 1 deletion torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TYPE_CHECKING

import PIL.Image
import torch
from torch.nn.functional import conv2d
Expand All @@ -9,7 +11,15 @@

from ._misc import _num_value_bits, to_dtype_image
from ._type_conversion import pil_to_tensor, to_pil_image
from ._utils import _get_kernel, _register_kernel_internal
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal


# Evaluated once at import time; gates the optional CV-CUDA kernel registrations in this module.
CVCUDA_AVAILABLE = _is_cvcuda_available()

if TYPE_CHECKING:
    # Import seen only by static type checkers, for the "cvcuda.Tensor" string annotations below.
    import cvcuda  # type: ignore[import-not-found]
if CVCUDA_AVAILABLE:
    # At runtime the module is bound through the project helper rather than a plain import.
    cvcuda = _import_cvcuda()  # noqa: F811


def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
Expand Down Expand Up @@ -63,6 +73,38 @@ def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int
return _FP.to_grayscale(image, num_output_channels=num_output_channels)


def _rgb_to_grayscale_image_cvcuda(
    image: "cvcuda.Tensor",
    num_output_channels: int = 1,
) -> "cvcuda.Tensor":
    """Convert a CV-CUDA image tensor to grayscale.

    Args:
        image: CV-CUDA tensor whose channel count is read from ``shape[3]``
            (assumes the NHWC layout used by CV-CUDA image tensors -- TODO confirm).
        num_output_channels: Number of output channels. Only ``1`` is accepted here.

    Returns:
        A new CV-CUDA tensor: a copy of ``image`` when it already has one channel,
        otherwise the result of an ``RGB2GRAY`` colour conversion.

    Raises:
        ValueError: If ``num_output_channels`` is not 1 or 3, or if it is 3
            (not supported on the CV-CUDA path).
    """
    cvcuda = _import_cvcuda()

    # Validate in two steps so each failure mode keeps its own error message.
    if num_output_channels not in (1, 3):
        raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
    if num_output_channels == 3:
        raise ValueError("num_output_channels must be 1 for CV-CUDA, got 3.")

    single_channel = image.shape[3] == 1
    if not single_channel:
        return cvcuda.cvtcolor(image, cvcuda.ColorConversion.RGB2GRAY)

    # Already grayscale: hand back a copy. CV-CUDA has no native clone, so a
    # copymakeborder call with a zero-width border on every side is used instead.
    zero_border = {"top": 0, "left": 0, "bottom": 0, "right": 0}
    return cvcuda.copymakeborder(
        image,
        border_mode=cvcuda.Border.CONSTANT,
        border_value=[0],
        **zero_border,
    )


if CVCUDA_AVAILABLE:
    # Register the CV-CUDA kernel for rgb_to_grayscale only when CV-CUDA is available.
    _register_kernel_internal(rgb_to_grayscale, _import_cvcuda().Tensor)(_rgb_to_grayscale_image_cvcuda)


def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor:
"""See :class:`~torchvision.transforms.v2.RGB` for details."""
if torch.jit.is_scripting():
Expand All @@ -89,6 +131,31 @@ def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return image.convert(mode="RGB")


def _grayscale_to_rgb_image_cvcuda(
    image: "cvcuda.Tensor",
) -> "cvcuda.Tensor":
    """Convert a CV-CUDA image tensor to RGB.

    Args:
        image: CV-CUDA tensor whose channel count is read from ``shape[3]``
            (assumes the NHWC layout used by CV-CUDA image tensors -- TODO confirm).

    Returns:
        A new CV-CUDA tensor: a copy of ``image`` when it already has three channels,
        otherwise the result of a ``GRAY2RGB`` colour conversion.
    """
    cvcuda = _import_cvcuda()

    needs_conversion = image.shape[3] != 3
    if needs_conversion:
        return cvcuda.cvtcolor(image, cvcuda.ColorConversion.GRAY2RGB)

    # Already RGB: hand back a copy. CV-CUDA has no native clone, so a
    # copymakeborder call with a zero-width border on every side is used instead.
    zero_border = {"top": 0, "left": 0, "bottom": 0, "right": 0}
    return cvcuda.copymakeborder(
        image,
        border_mode=cvcuda.Border.CONSTANT,
        border_value=[0],
        **zero_border,
    )


if CVCUDA_AVAILABLE:
    # Register the CV-CUDA kernel for grayscale_to_rgb only when CV-CUDA is available.
    _register_kernel_internal(grayscale_to_rgb, _import_cvcuda().Tensor)(_grayscale_to_rgb_image_cvcuda)


def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
ratio = float(ratio)
fp = image1.is_floating_point()
Expand Down
22 changes: 21 additions & 1 deletion torchvision/transforms/v2/functional/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]:
return get_dimensions_image(video)


def get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:
    """Return ``[channels, height, width]`` for a CV-CUDA image tensor.

    CV-CUDA tensors are laid out as NHWC, whereas ``get_dimensions`` reports CHW,
    so the shape entries are picked out and reordered.
    """
    shape = image.shape
    channels, height, width = shape[3], shape[1], shape[2]
    return [channels, height, width]


if CVCUDA_AVAILABLE:
    # Resolve the tensor class via _import_cvcuda(), as the get_size registration in this
    # module does, instead of relying on a module-level `cvcuda` name.
    _register_kernel_internal(get_dimensions, _import_cvcuda().Tensor)(get_dimensions_image_cvcuda)


def get_num_channels(inpt: torch.Tensor) -> int:
if torch.jit.is_scripting():
return get_num_channels_image(inpt)
Expand Down Expand Up @@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int:
get_image_num_channels = get_num_channels


def get_num_channels_image_cvcuda(image: "cvcuda.Tensor") -> int:
    """Return the number of channels of a CV-CUDA image tensor.

    CV-CUDA tensors are laid out as NHWC, so the channel count is the entry at index 3.
    """
    channel_axis = 3
    return image.shape[channel_axis]


if CVCUDA_AVAILABLE:
    # Resolve the tensor class via _import_cvcuda(), as the get_size registration in this
    # module does, instead of relying on a module-level `cvcuda` name.
    _register_kernel_internal(get_num_channels, _import_cvcuda().Tensor)(get_num_channels_image_cvcuda)


def get_size(inpt: torch.Tensor) -> list[int]:
if torch.jit.is_scripting():
return get_size_image(inpt)
Expand Down Expand Up @@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:


if CVCUDA_AVAILABLE:
_get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda)
_register_kernel_internal(get_size, _import_cvcuda().Tensor)(get_size_image_cvcuda)


@_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
Expand Down