49 changes: 48 additions & 1 deletion test/test_transforms_v2.py
@@ -21,6 +21,7 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
@@ -41,7 +42,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
@@ -3354,6 +3354,9 @@ def test_kernel_video(self):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available")
),
],
)
def test_functional(self, make_input):
@@ -3369,9 +3372,16 @@ def test_functional(self, make_input):
(F.elastic_mask, tv_tensors.Mask),
(F.elastic_video, tv_tensors.Video),
(F.elastic_keypoints, tv_tensors.KeyPoints),
pytest.param(
F._geometry._elastic_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available"),
),
],
)
def test_functional_signature(self, kernel, input_type):
if kernel is F._geometry._elastic_image_cvcuda:
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.elastic, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
@@ -3384,6 +3394,9 @@ def test_functional_signature(self, kernel, input_type):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available")
),
],
)
def test_displacement_error(self, make_input):
@@ -3405,6 +3418,9 @@ def test_displacement_error(self, make_input):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available")
),
],
)
    # ElasticTransform needs larger images so that the internal padding it requires does not exceed the actual image size
@@ -3422,6 +3438,37 @@ def test_transform(self, make_input, size, device):
check_v1_compatibility=check_v1_compatibility,
)

@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available")
@needs_cuda
@pytest.mark.parametrize(
"interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
)
def test_image_cvcuda_correctness(self, interpolation):
image = make_image_cvcuda(dtype=torch.uint8)
displacement = self._make_displacement(image)

result = F._geometry._elastic_image_cvcuda(image, displacement=displacement, interpolation=interpolation)
result = F.cvcuda_to_tensor(result)

expected = F._geometry.elastic_image(
F.cvcuda_to_tensor(image), displacement=displacement, interpolation=interpolation
)

        # Mainly checks that properties other than pixel values (shape, dtype, device)
        # match; the loose atol makes this insensitive to the pixel values themselves.
        # See the note below on pixel-value differences.
        assert_close(result, expected, atol=get_max_value(torch.uint8), rtol=0)

        # Visually the results are identical, but the underlying computations differ,
        # so we compare against an MAE threshold chosen per interpolation mode.
        # The main difference is along the borders, where pixels can be shifted by up
        # to one position, producing a difference of up to 255 on a single pixel. This
        # is likely because torchvision fills out-of-bounds pixels with 0 while
        # CV-CUDA's shifted sampling picks up an actual color. The thresholds shrink
        # as the image grows: a (640, 480) input passes with 20.0 (NEAREST) and
        # 13.0 (BILINEAR).
        mae = (expected.float() - result.float()).abs().mean()
        mae_threshold = 30.0 if interpolation is transforms.InterpolationMode.NEAREST else 20.0
        assert mae < mae_threshold, f"MAE {mae} exceeds threshold {mae_threshold}"
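
A quick aside on the tolerance scheme above, since uint8 comparisons are easy to get wrong: casting to float before subtracting avoids uint8 wraparound, and a one-pixel-wide border shift bounds the MAE well below the thresholds. A minimal self-contained sketch (the tensors here are illustrative, not taken from the kernels):

import torch

# Two uint8 "images" that differ only in one border column, the failure
# mode described in the comments above.
a = torch.zeros(1, 3, 16, 16, dtype=torch.uint8)
b = a.clone()
b[..., 0] = 255  # first column shifted to a different color

# Cast to float first: (a - b) in uint8 would wrap around modulo 256.
mae = (a.float() - b.float()).abs().mean()  # 255 * 48 / 768 = 15.9375
assert mae < 20.0  # comfortably under the BILINEAR threshold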


class TestToPureTensor:
def test_correctness(self):
2 changes: 2 additions & 0 deletions torchvision/transforms/v2/_geometry.py
@@ -1045,6 +1045,8 @@ class ElasticTransform(Transform):

_v1_transform_cls = _transforms.ElasticTransform

_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def __init__(
self,
alpha: Union[float, Sequence[float]] = 50.0,
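Adding `_is_cvcuda_tensor` to `_transformed_types` means the class-based transform should treat CV-CUDA tensors as transformable inputs rather than passing them through unchanged (together with the `query_size` change below). A hedged sketch of what this enables, assuming CV-CUDA and a CUDA device are available; the `cvcuda.as_tensor` wrap mirrors the one used in the kernel further down:

import torch
import cvcuda  # type: ignore[import-not-found]
import torchvision.transforms.v2 as transforms

img = torch.randint(0, 256, (1, 64, 64, 3), dtype=torch.uint8, device="cuda")
cv_img = cvcuda.as_tensor(img, "NHWC")  # zero-copy wrap as a cvcuda.Tensor

# With this change, ElasticTransform dispatches on cvcuda.Tensor inputs
# instead of leaving them untouched.
out = transforms.ElasticTransform(alpha=50.0)(cv_img)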
3 changes: 2 additions & 1 deletion torchvision/transforms/v2/_utils.py
@@ -16,7 +16,7 @@

from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor


def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]:
@@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
tv_tensors.Mask,
tv_tensors.BoundingBoxes,
tv_tensors.KeyPoints,
_is_cvcuda_tensor,
),
)
}
78 changes: 78 additions & 0 deletions torchvision/transforms/v2/functional/_geometry.py
@@ -4,6 +4,7 @@
from collections.abc import Sequence
from typing import Any, Optional, TYPE_CHECKING, Union

import numpy as np
import PIL.Image
import torch
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
@@ -28,6 +29,7 @@

from ._utils import (
_FillTypeJIT,
_get_cvcuda_interp,
_get_kernel,
_import_cvcuda,
_is_cvcuda_available,
@@ -2529,6 +2531,82 @@ def elastic_video(
return elastic_image(video, displacement, interpolation=interpolation, fill=fill)


def _elastic_image_cvcuda(
image: "cvcuda.Tensor",
displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: _FillTypeJIT = None,
) -> "cvcuda.Tensor":
cvcuda = _import_cvcuda()

if not isinstance(displacement, torch.Tensor):
raise TypeError("Argument displacement should be a Tensor")

    batch_size, height, width, num_channels = image.shape
    # CV-CUDA tensors always live on a CUDA device; the grid math runs in float32.
    device = torch.device("cuda")
    dtype = torch.float32

expected_shape = (1, height, width, 2)
if expected_shape != displacement.shape:
raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

# cvcuda.remap only supports uint8 for 3-channel images, float32 for 1-channel
input_dtype = image.dtype
if num_channels == 3 and input_dtype != cvcuda.Type.U8:
raise ValueError(f"cvcuda.remap requires uint8 dtype for 3-channel images, but got {input_dtype}")
elif num_channels == 1 and input_dtype != cvcuda.Type.F32:
raise ValueError(f"cvcuda.remap requires float32 dtype for 1-channel images, but got {input_dtype}")

interp = _get_cvcuda_interp(interpolation)

# Build normalized grid: identity + displacement
# _create_identity_grid returns (1, H, W, 2) with values in [-1, 1]
identity_grid = _create_identity_grid((height, width), device=device, dtype=dtype)
grid = identity_grid.add_(displacement.to(dtype=dtype, device=device))

# Convert normalized grid [-1, 1] to absolute pixel coordinates [0, width-1], [0, height-1]
# grid[..., 0] is x (horizontal), grid[..., 1] is y (vertical)
map_x = (grid[..., 0] + 1) * (width - 1) / 2.0
map_y = (grid[..., 1] + 1) * (height - 1) / 2.0
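    # Worked example of the mapping: with width=640, grid_x=-1 -> 0.0,
    # grid_x=0 -> 319.5, grid_x=1 -> 639.0 (and analogously for map_y).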

# Stack into (1, H, W, 2) map tensor
pixel_map = torch.stack([map_x, map_y], dim=-1)

# Expand map for batch if needed
if batch_size > 1:
pixel_map = pixel_map.expand(batch_size, -1, -1, -1)

# Create cvcuda map tensor (NHWC layout with 2 channels for x,y)
cv_map = cvcuda.as_tensor(pixel_map.contiguous(), "NHWC")

    # CV-CUDA expects the constant border fill as a float32 array; an empty
    # array falls back to the operator's default fill.
    border_mode = cvcuda.Border.CONSTANT
    if fill is None:
        border_value = np.array([], dtype=np.float32)
    elif isinstance(fill, (int, float)):
        border_value = np.array([fill], dtype=np.float32)
    elif isinstance(fill, (list, tuple)):
        border_value = np.array(fill, dtype=np.float32)
    else:
        # Unrecognized fill types also fall back to the default border value.
        border_value = np.array([], dtype=np.float32)

output = cvcuda.remap(
image,
cv_map,
src_interp=interp,
map_interp=cvcuda.Interp.LINEAR,
map_type=cvcuda.Remap.ABSOLUTE,
align_corners=False,
border=border_mode,
border_value=border_value,
)

return output


if CVCUDA_AVAILABLE:
_register_kernel_internal(elastic, _import_cvcuda().Tensor)(_elastic_image_cvcuda)


def center_crop(inpt: torch.Tensor, output_size: list[int]) -> torch.Tensor:
"""See :class:`~torchvision.transforms.v2.RandomCrop` for details."""
if torch.jit.is_scripting():
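Because the kernel is registered for `cvcuda.Tensor`, the public `F.elastic` entry point dispatches to it directly. A minimal sketch, assuming CV-CUDA is installed and a CUDA device is available (shapes are illustrative):

import torch
import cvcuda  # type: ignore[import-not-found]
import torchvision.transforms.v2.functional as F

img = torch.randint(0, 256, (1, 32, 32, 3), dtype=torch.uint8, device="cuda")
cv_img = cvcuda.as_tensor(img, "NHWC")

# An all-zeros displacement is the identity mapping, so the output should
# match the input up to interpolation effects.
displacement = torch.zeros(1, 32, 32, 2)

out = F.elastic(cv_img, displacement=displacement)  # -> _elastic_image_cvcuda
out_pt = F.cvcuda_to_tensor(out)  # back to a torch tensor, as in the tests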
48 changes: 47 additions & 1 deletion torchvision/transforms/v2/functional/_utils.py
@@ -1,9 +1,13 @@
import functools
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch
from torchvision import tv_tensors
from torchvision.transforms.functional import InterpolationMode

if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]

_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[list[float]]
@@ -177,3 +181,45 @@ def _is_cvcuda_tensor(inpt: Any) -> bool:
return isinstance(inpt, cvcuda.Tensor)
except ImportError:
return False


_interpolation_mode_to_cvcuda_interp: dict[InterpolationMode | str | int, "cvcuda.Interp"] = {}


def _get_cvcuda_interp(interpolation: InterpolationMode | str | int) -> "cvcuda.Interp":
"""
    Map a TorchVision interpolation mode to the corresponding CV-CUDA interpolation mode.

    CV-CUDA has the following two differences (observed in tests) compared to TorchVision/PIL:
    1. CV-CUDA has no exact match for NEAREST; its Interp.NEAREST actually behaves like
       NEAREST_EXACT. Since we still need nearest-neighbor interpolation, we map both
       NEAREST and NEAREST_EXACT to Interp.NEAREST.
    2. Its BICUBIC interpolation differs algorithmically from TorchVision/PIL.
"""
if len(_interpolation_mode_to_cvcuda_interp) == 0:
cvcuda = _import_cvcuda()
_interpolation_mode_to_cvcuda_interp[InterpolationMode.NEAREST] = cvcuda.Interp.NEAREST
_interpolation_mode_to_cvcuda_interp[InterpolationMode.NEAREST_EXACT] = cvcuda.Interp.NEAREST
_interpolation_mode_to_cvcuda_interp[InterpolationMode.BILINEAR] = cvcuda.Interp.LINEAR
_interpolation_mode_to_cvcuda_interp[InterpolationMode.BICUBIC] = cvcuda.Interp.CUBIC
_interpolation_mode_to_cvcuda_interp[InterpolationMode.BOX] = cvcuda.Interp.BOX
_interpolation_mode_to_cvcuda_interp[InterpolationMode.HAMMING] = cvcuda.Interp.HAMMING
_interpolation_mode_to_cvcuda_interp[InterpolationMode.LANCZOS] = cvcuda.Interp.LANCZOS
_interpolation_mode_to_cvcuda_interp["nearest"] = cvcuda.Interp.NEAREST
_interpolation_mode_to_cvcuda_interp["nearest-exact"] = cvcuda.Interp.NEAREST
_interpolation_mode_to_cvcuda_interp["bilinear"] = cvcuda.Interp.LINEAR
_interpolation_mode_to_cvcuda_interp["bicubic"] = cvcuda.Interp.CUBIC
_interpolation_mode_to_cvcuda_interp["box"] = cvcuda.Interp.BOX
_interpolation_mode_to_cvcuda_interp["hamming"] = cvcuda.Interp.HAMMING
_interpolation_mode_to_cvcuda_interp["lanczos"] = cvcuda.Interp.LANCZOS
        # Plain-integer modes follow the PIL resampling constants:
        # NEAREST=0, LANCZOS=1, BILINEAR=2, BICUBIC=3, BOX=4, HAMMING=5
        _interpolation_mode_to_cvcuda_interp[0] = cvcuda.Interp.NEAREST
        _interpolation_mode_to_cvcuda_interp[1] = cvcuda.Interp.LANCZOS
        _interpolation_mode_to_cvcuda_interp[2] = cvcuda.Interp.LINEAR
        _interpolation_mode_to_cvcuda_interp[3] = cvcuda.Interp.CUBIC
        _interpolation_mode_to_cvcuda_interp[4] = cvcuda.Interp.BOX
        _interpolation_mode_to_cvcuda_interp[5] = cvcuda.Interp.HAMMING

interp = _interpolation_mode_to_cvcuda_interp.get(interpolation)
if interp is None:
raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA")

return interp
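
For reference, a short sketch of the lazily populated cache in use (assuming CV-CUDA is importable; return values shown as comments):

from torchvision.transforms.functional import InterpolationMode

_get_cvcuda_interp(InterpolationMode.BILINEAR)  # -> cvcuda.Interp.LINEAR
_get_cvcuda_interp("bicubic")                   # -> cvcuda.Interp.CUBIC
_get_cvcuda_interp(2)                           # PIL constant for BILINEAR -> cvcuda.Interp.LINEAR
_get_cvcuda_interp("antialias")                 # raises ValueError: not supported with CV-CUDA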