diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 3ce603c3ed2..f34e6a730d8 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -21,6 +21,7 @@ import torchvision.transforms.v2 as transforms from common_utils import ( + assert_close, assert_equal, cache, cpu_and_cuda, @@ -41,7 +42,6 @@ ) from torch import nn -from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors @@ -3354,6 +3354,9 @@ def test_kernel_video(self): make_segmentation_mask, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available") + ), ], ) def test_functional(self, make_input): @@ -3369,9 +3372,16 @@ def test_functional(self, make_input): (F.elastic_mask, tv_tensors.Mask), (F.elastic_video, tv_tensors.Video), (F.elastic_keypoints, tv_tensors.KeyPoints), + pytest.param( + F._geometry._elastic_image_cvcuda, + None, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if kernel is F._geometry._elastic_image_cvcuda: + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.elastic, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( @@ -3384,6 +3394,9 @@ def test_functional_signature(self, kernel, input_type): make_segmentation_mask, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available") + ), ], ) def test_displacement_error(self, make_input): @@ -3405,6 +3418,9 @@ def test_displacement_error(self, make_input): make_segmentation_mask, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available") + ), ], ) # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image @@ -3422,6 +3438,37 @@ def test_transform(self, make_input, size, device): check_v1_compatibility=check_v1_compatibility, ) + @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available") + @needs_cuda + @pytest.mark.parametrize( + "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR] + ) + def test_image_cvcuda_correctness(self, interpolation): + image = make_image_cvcuda(dtype=torch.uint8) + displacement = self._make_displacement(image) + + result = F._geometry._elastic_image_cvcuda(image, displacement=displacement, interpolation=interpolation) + result = F.cvcuda_to_tensor(result) + + expected = F._geometry.elastic_image( + F.cvcuda_to_tensor(image), displacement=displacement, interpolation=interpolation + ) + + # mainly for checking properties (outside pixel values) are correct + # see note below on pixel-value differences + assert_close(result, expected, atol=get_max_value(torch.uint8), rtol=0) + + # visually, the results are identical, however the underlying computations are different + # we can define an mae_threshold based on the interpolation mode + # the primary difference is along the borders where pixels appear to be shifted in location + # by up to 1, causing potentially up to a diff of 255 on a single pixel + # this could be because one has fill of 0 and CV-CUDA is shifted and has value with some color + # thresholds decrease as image size gets larger + # (640, 480) input, has 20.0, 13.0 respectively to pass + mae = (expected.float() - result.float()).abs().mean() + mae_threshold = 30.0 if interpolation is transforms.InterpolationMode.NEAREST else 20.0 + assert mae < mae_threshold, f"MAE {mae} exceeds threshold" + class TestToPureTensor: def test_correctness(self): diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 96166e05e9a..7bc84112676 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -1045,6 +1045,8 @@ class ElasticTransform(Transform): _v1_transform_cls = _transforms.ElasticTransform + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) + def __init__( self, alpha: Union[float, Sequence[float]] = 50.0, diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index bb6051b4e61..7274abaa861 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -16,7 +16,7 @@ from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor -from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT +from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]: @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, + _is_cvcuda_tensor, ), ) } diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 0e27218bc89..aa63e69934f 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -4,6 +4,7 @@ from collections.abc import Sequence from typing import Any, Optional, TYPE_CHECKING, Union +import numpy as np import PIL.Image import torch from torch.nn.functional import grid_sample, interpolate, pad as torch_pad @@ -28,6 +29,7 @@ from ._utils import ( _FillTypeJIT, + _get_cvcuda_interp, _get_kernel, _import_cvcuda, _is_cvcuda_available, @@ -2529,6 +2531,82 @@ def elastic_video( return elastic_image(video, displacement, interpolation=interpolation, fill=fill) +def _elastic_image_cvcuda( + image: "cvcuda.Tensor", + displacement: torch.Tensor, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: _FillTypeJIT = None, +) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + + if not isinstance(displacement, torch.Tensor): + raise TypeError("Argument displacement should be a Tensor") + + batch_size, height, width, num_channels = image.shape + device = torch.device("cuda") + dtype = torch.float32 + + expected_shape = (1, height, width, 2) + if expected_shape != displacement.shape: + raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}") + + # cvcuda.remap only supports uint8 for 3-channel images, float32 for 1-channel + input_dtype = image.dtype + if num_channels == 3 and input_dtype != cvcuda.Type.U8: + raise ValueError(f"cvcuda.remap requires uint8 dtype for 3-channel images, but got {input_dtype}") + elif num_channels == 1 and input_dtype != cvcuda.Type.F32: + raise ValueError(f"cvcuda.remap requires float32 dtype for 1-channel images, but got {input_dtype}") + + interp = _get_cvcuda_interp(interpolation) + + # Build normalized grid: identity + displacement + # _create_identity_grid returns (1, H, W, 2) with values in [-1, 1] + identity_grid = _create_identity_grid((height, width), device=device, dtype=dtype) + grid = identity_grid.add_(displacement.to(dtype=dtype, device=device)) + + # Convert normalized grid [-1, 1] to absolute pixel coordinates [0, width-1], [0, height-1] + # grid[..., 0] is x (horizontal), grid[..., 1] is y (vertical) + map_x = (grid[..., 0] + 1) * (width - 1) / 2.0 + map_y = (grid[..., 1] + 1) * (height - 1) / 2.0 + + # Stack into (1, H, W, 2) map tensor + pixel_map = torch.stack([map_x, map_y], dim=-1) + + # Expand map for batch if needed + if batch_size > 1: + pixel_map = pixel_map.expand(batch_size, -1, -1, -1) + + # Create cvcuda map tensor (NHWC layout with 2 channels for x,y) + cv_map = cvcuda.as_tensor(pixel_map.contiguous(), "NHWC") + + border_mode = cvcuda.Border.CONSTANT + if fill is None: + border_value = np.array([], dtype=np.float32) + elif isinstance(fill, (int, float)): + border_value = np.array([fill], dtype=np.float32) + elif isinstance(fill, (list, tuple)): + border_value = np.array(fill, dtype=np.float32) + else: + border_value = np.array([], dtype=np.float32) + + output = cvcuda.remap( + image, + cv_map, + src_interp=interp, + map_interp=cvcuda.Interp.LINEAR, + map_type=cvcuda.Remap.ABSOLUTE, + align_corners=False, + border=border_mode, + border_value=border_value, + ) + + return output + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(elastic, _import_cvcuda().Tensor)(_elastic_image_cvcuda) + + def center_crop(inpt: torch.Tensor, output_size: list[int]) -> torch.Tensor: """See :class:`~torchvision.transforms.v2.RandomCrop` for details.""" if torch.jit.is_scripting(): diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 11480b30ef9..b924bb16d38 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -1,9 +1,13 @@ import functools from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch from torchvision import tv_tensors +from torchvision.transforms.functional import InterpolationMode + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] _FillType = Union[int, float, Sequence[int], Sequence[float], None] _FillTypeJIT = Optional[list[float]] @@ -177,3 +181,45 @@ def _is_cvcuda_tensor(inpt: Any) -> bool: return isinstance(inpt, cvcuda.Tensor) except ImportError: return False + + +_interpolation_mode_to_cvcuda_interp: dict[InterpolationMode | str | int, "cvcuda.Interp"] = {} + + +def _get_cvcuda_interp(interpolation: InterpolationMode | str | int) -> "cvcuda.Interp": + """ + Get the CV-CUDA interpolation mode for a given interpolation mode. + + CV-CUDA has the two following differences (evaluated in tests) comapred to TorchVision/PIL: + 1. CV-CUDA does not have a match for NEAREST, its Interp.NEAREST is actually NEAREST_EXACT + Since we need to do interpolation, we will map NEAREST to Interp.NEAREST (which is NEAREST_EXACT) + 2. BICUBIC interpolation method is different compared to TorchVision/PIL, algorithmic difference + """ + if len(_interpolation_mode_to_cvcuda_interp) == 0: + cvcuda = _import_cvcuda() + _interpolation_mode_to_cvcuda_interp[InterpolationMode.NEAREST] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp[InterpolationMode.NEAREST_EXACT] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp[InterpolationMode.BILINEAR] = cvcuda.Interp.LINEAR + _interpolation_mode_to_cvcuda_interp[InterpolationMode.BICUBIC] = cvcuda.Interp.CUBIC + _interpolation_mode_to_cvcuda_interp[InterpolationMode.BOX] = cvcuda.Interp.BOX + _interpolation_mode_to_cvcuda_interp[InterpolationMode.HAMMING] = cvcuda.Interp.HAMMING + _interpolation_mode_to_cvcuda_interp[InterpolationMode.LANCZOS] = cvcuda.Interp.LANCZOS + _interpolation_mode_to_cvcuda_interp["nearest"] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp["nearest-exact"] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp["bilinear"] = cvcuda.Interp.LINEAR + _interpolation_mode_to_cvcuda_interp["bicubic"] = cvcuda.Interp.CUBIC + _interpolation_mode_to_cvcuda_interp["box"] = cvcuda.Interp.BOX + _interpolation_mode_to_cvcuda_interp["hamming"] = cvcuda.Interp.HAMMING + _interpolation_mode_to_cvcuda_interp["lanczos"] = cvcuda.Interp.LANCZOS + _interpolation_mode_to_cvcuda_interp[0] = cvcuda.Interp.NEAREST + _interpolation_mode_to_cvcuda_interp[2] = cvcuda.Interp.LINEAR + _interpolation_mode_to_cvcuda_interp[3] = cvcuda.Interp.CUBIC + _interpolation_mode_to_cvcuda_interp[4] = cvcuda.Interp.BOX + _interpolation_mode_to_cvcuda_interp[5] = cvcuda.Interp.HAMMING + _interpolation_mode_to_cvcuda_interp[1] = cvcuda.Interp.LANCZOS + + interp = _interpolation_mode_to_cvcuda_interp.get(interpolation) + if interp is None: + raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA") + + return interp