Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
Expand All @@ -41,7 +42,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
Expand Down Expand Up @@ -4727,6 +4727,9 @@ def test_kernel_video(self):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_functional(self, make_input):
Expand All @@ -4745,9 +4748,16 @@ def test_functional(self, make_input):
(F.pad_bounding_boxes, tv_tensors.BoundingBoxes),
(F.pad_mask, tv_tensors.Mask),
(F.pad_video, tv_tensors.Video),
pytest.param(
F._geometry._pad_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
),
],
)
def test_functional_signature(self, kernel, input_type):
if kernel is F._geometry._pad_image_cvcuda:
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.pad, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
Expand All @@ -4760,6 +4770,9 @@ def test_functional_signature(self, kernel, input_type):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_transform(self, make_input):
Expand All @@ -4784,6 +4797,15 @@ def test_transform_errors(self):
with pytest.raises(ValueError, match="Padding mode should be either"):
transforms.Pad(12, padding_mode="abc")

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
@pytest.mark.parametrize("padding", CORRECTNESS_PADDINGS)
@pytest.mark.parametrize(
("padding_mode", "fill"),
Expand All @@ -4793,12 +4815,16 @@ def test_transform_errors(self):
],
)
@pytest.mark.parametrize("fn", [F.pad, transform_cls_to_functional(transforms.Pad)])
def test_image_correctness(self, padding, padding_mode, fill, fn):
image = make_image(dtype=torch.uint8, device="cpu")
def test_image_correctness(self, make_input, padding, padding_mode, fill, fn):
image = make_input(dtype=torch.uint8, device="cpu")

fill = adapt_fill(fill, dtype=torch.uint8)

actual = fn(image, padding=padding, padding_mode=padding_mode, fill=fill)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(F.pad(F.to_pil_image(image), padding=padding, padding_mode=padding_mode, fill=fill))

assert_equal(actual, expected)
Expand Down
2 changes: 2 additions & 0 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,8 @@ class Pad(Transform):

_v1_transform_cls = _transforms.Pad

_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def _extract_params_for_v1_transform(self) -> dict[str, Any]:
params = super()._extract_params_for_v1_transform()

Expand Down
33 changes: 33 additions & 0 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from ._utils import (
_FillTypeJIT,
_get_cvcuda_border_from_pad_mode,
_get_kernel,
_import_cvcuda,
_is_cvcuda_available,
Expand Down Expand Up @@ -1682,6 +1683,38 @@ def _pad_with_vector_fill(
_pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad)


def _pad_image_cvcuda(
image: "cvcuda.Tensor",
padding: list[int],
fill: Optional[Union[int, float, list[float]]] = None,
padding_mode: str = "constant",
) -> "cvcuda.Tensor":
cvcuda = _import_cvcuda()

border_mode = _get_cvcuda_border_from_pad_mode(padding_mode)

if fill is None:
fill = 0
if isinstance(fill, (int, float)):
fill = [fill] * image.shape[3]

left, right, top, bottom = _parse_pad_padding(padding)

return cvcuda.copymakeborder(
image,
border_mode=border_mode,
border_value=fill,
top=top,
left=left,
bottom=bottom,
right=right,
)


if CVCUDA_AVAILABLE:
_register_kernel_internal(pad, _import_cvcuda().Tensor)(_pad_image_cvcuda)


@_register_kernel_internal(pad, tv_tensors.Mask)
def pad_mask(
mask: torch.Tensor,
Expand Down
24 changes: 23 additions & 1 deletion torchvision/transforms/v2/functional/_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import functools
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch
from torchvision import tv_tensors

if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]

_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[list[float]]

Expand Down Expand Up @@ -177,3 +180,22 @@ def _is_cvcuda_tensor(inpt: Any) -> bool:
return isinstance(inpt, cvcuda.Tensor)
except ImportError:
return False


_pad_mode_to_cvcuda_border: dict[str, "cvcuda.Border"] = {}


def _get_cvcuda_border_from_pad_mode(pad_mode: str) -> "cvcuda.Border":
if len(_pad_mode_to_cvcuda_border) == 0:
cvcuda = _import_cvcuda()
_pad_mode_to_cvcuda_border["constant"] = cvcuda.Border.CONSTANT
_pad_mode_to_cvcuda_border["reflect"] = cvcuda.Border.REFLECT101
_pad_mode_to_cvcuda_border["replicate"] = cvcuda.Border.REPLICATE
_pad_mode_to_cvcuda_border["edge"] = cvcuda.Border.REPLICATE
_pad_mode_to_cvcuda_border["symmetric"] = cvcuda.Border.REFLECT

border_mode = _pad_mode_to_cvcuda_border.get(pad_mode)
if border_mode is None:
raise ValueError(f"Pad mode {pad_mode} is not supported with CV-CUDA")

return border_mode