Skip to content

Commit be9183a

Browse files
committed
rebase to main standards
1 parent c362120 commit be9183a

File tree

7 files changed

+62
-66
lines changed

7 files changed

+62
-66
lines changed

test/test_transforms_v2.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
assert_equal,
2626
cache,
2727
cpu_and_cuda,
28-
cvcuda_to_pil_compatible_tensor,
2928
freeze_rng_state,
3029
ignore_jit_no_profile_information_warning,
3130
make_bounding_boxes,
@@ -3374,14 +3373,14 @@ def test_functional(self, make_input):
33743373
(F.elastic_video, tv_tensors.Video),
33753374
(F.elastic_keypoints, tv_tensors.KeyPoints),
33763375
pytest.param(
3377-
F._geometry._elastic_cvcuda,
3378-
"cvcuda.Tensor",
3376+
F._geometry._elastic_image_cvcuda,
3377+
None,
33793378
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CV-CUDA not available"),
33803379
),
33813380
],
33823381
)
33833382
def test_functional_signature(self, kernel, input_type):
3384-
if input_type == "cvcuda.Tensor":
3383+
if kernel is F._geometry._elastic_image_cvcuda:
33853384
input_type = _import_cvcuda().Tensor
33863385
check_functional_kernel_signature_match(F.elastic, kernel=kernel, input_type=input_type)
33873386

torchvision/transforms/v2/_geometry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
is_pure_tensor,
3030
query_size,
3131
)
32-
from .functional._utils import is_cvcuda_tensor
3332

3433
CVCUDA_AVAILABLE = _is_cvcuda_available()
3534

@@ -1046,7 +1045,8 @@ class ElasticTransform(Transform):
10461045

10471046
_v1_transform_cls = _transforms.ElasticTransform
10481047

1049-
_transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
1048+
if CVCUDA_AVAILABLE:
1049+
_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)
10501050

10511051
def __init__(
10521052
self,

torchvision/transforms/v2/functional/_augment.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import io
2-
from typing import TYPE_CHECKING
32

43
import PIL.Image
54

@@ -9,15 +8,7 @@
98
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
109
from torchvision.utils import _log_api_usage_once
1110

12-
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal
13-
14-
15-
CVCUDA_AVAILABLE = _is_cvcuda_available()
16-
17-
if TYPE_CHECKING:
18-
import cvcuda # type: ignore[import-not-found]
19-
if CVCUDA_AVAILABLE:
20-
cvcuda = _import_cvcuda() # noqa: F811
11+
from ._utils import _get_kernel, _register_kernel_internal
2112

2213

2314
def erase(

torchvision/transforms/v2/functional/_color.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import TYPE_CHECKING
2-
31
import PIL.Image
42
import torch
53
from torch.nn.functional import conv2d
@@ -11,15 +9,7 @@
119

1210
from ._misc import _num_value_bits, to_dtype_image
1311
from ._type_conversion import pil_to_tensor, to_pil_image
14-
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal
15-
16-
17-
CVCUDA_AVAILABLE = _is_cvcuda_available()
18-
19-
if TYPE_CHECKING:
20-
import cvcuda # type: ignore[import-not-found]
21-
if CVCUDA_AVAILABLE:
22-
cvcuda = _import_cvcuda() # noqa: F811
12+
from ._utils import _get_kernel, _register_kernel_internal
2313

2414

2515
def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from ._utils import (
3131
_FillTypeJIT,
32+
_get_cvcuda_interp,
3233
_get_kernel,
3334
_import_cvcuda,
3435
_is_cvcuda_available,
@@ -2530,36 +2531,14 @@ def elastic_video(
25302531
return elastic_image(video, displacement, interpolation=interpolation, fill=fill)
25312532

25322533

2533-
if CVCUDA_AVAILABLE:
2534-
_cvcuda_interp = {
2535-
InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR,
2536-
"bilinear": cvcuda.Interp.LINEAR,
2537-
"linear": cvcuda.Interp.LINEAR,
2538-
2: cvcuda.Interp.LINEAR,
2539-
InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC,
2540-
"bicubic": cvcuda.Interp.CUBIC,
2541-
3: cvcuda.Interp.CUBIC,
2542-
InterpolationMode.NEAREST: cvcuda.Interp.NEAREST,
2543-
"nearest": cvcuda.Interp.NEAREST,
2544-
0: cvcuda.Interp.NEAREST,
2545-
InterpolationMode.BOX: cvcuda.Interp.BOX,
2546-
"box": cvcuda.Interp.BOX,
2547-
4: cvcuda.Interp.BOX,
2548-
InterpolationMode.HAMMING: cvcuda.Interp.HAMMING,
2549-
"hamming": cvcuda.Interp.HAMMING,
2550-
5: cvcuda.Interp.HAMMING,
2551-
InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS,
2552-
"lanczos": cvcuda.Interp.LANCZOS,
2553-
1: cvcuda.Interp.LANCZOS,
2554-
}
2555-
2556-
2557-
def _elastic_cvcuda(
2534+
def _elastic_image_cvcuda(
25582535
image: "cvcuda.Tensor",
25592536
displacement: torch.Tensor,
25602537
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
25612538
fill: _FillTypeJIT = None,
25622539
) -> "cvcuda.Tensor":
2540+
cvcuda = _import_cvcuda()
2541+
25632542
if not isinstance(displacement, torch.Tensor):
25642543
raise TypeError("Argument displacement should be a Tensor")
25652544

@@ -2578,9 +2557,7 @@ def _elastic_cvcuda(
25782557
elif num_channels == 1 and input_dtype != cvcuda.Type.F32:
25792558
raise ValueError(f"cvcuda.remap requires float32 dtype for 1-channel images, but got {input_dtype}")
25802559

2581-
interp = _cvcuda_interp.get(interpolation, cvcuda.Interp.LINEAR)
2582-
if interp is None:
2583-
raise ValueError(f"Invalid interpolation mode: {interpolation}")
2560+
interp = _get_cvcuda_interp(interpolation)
25842561

25852562
# Build normalized grid: identity + displacement
25862563
# _create_identity_grid returns (1, H, W, 2) with values in [-1, 1]
@@ -2627,7 +2604,7 @@ def _elastic_cvcuda(
26272604

26282605

26292606
if CVCUDA_AVAILABLE:
2630-
_elastic_cvcuda = _register_kernel_internal(elastic, cvcuda.Tensor)(_elastic_cvcuda)
2607+
_register_kernel_internal(elastic, _import_cvcuda().Tensor)(_elastic_image_cvcuda)
26312608

26322609

26332610
def center_crop(inpt: torch.Tensor, output_size: list[int]) -> torch.Tensor:

torchvision/transforms/v2/functional/_misc.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Optional, TYPE_CHECKING
2+
from typing import Optional
33

44
import PIL.Image
55
import torch
@@ -13,14 +13,7 @@
1313

1414
from ._meta import _convert_bounding_box_format
1515

16-
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor
17-
18-
CVCUDA_AVAILABLE = _is_cvcuda_available()
19-
20-
if TYPE_CHECKING:
21-
import cvcuda # type: ignore[import-not-found]
22-
if CVCUDA_AVAILABLE:
23-
cvcuda = _import_cvcuda() # noqa: F811
16+
from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor
2417

2518

2619
def normalize(

torchvision/transforms/v2/functional/_utils.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import functools
22
from collections.abc import Sequence
3-
from typing import Any, Callable, Optional, Union
3+
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
44

55
import torch
66
from torchvision import tv_tensors
7+
from torchvision.transforms.functional import InterpolationMode
8+
9+
if TYPE_CHECKING:
10+
import cvcuda # type: ignore[import-not-found]
711

812
_FillType = Union[int, float, Sequence[int], Sequence[float], None]
913
_FillTypeJIT = Optional[list[float]]
@@ -177,3 +181,45 @@ def _is_cvcuda_tensor(inpt: Any) -> bool:
177181
return isinstance(inpt, cvcuda.Tensor)
178182
except ImportError:
179183
return False
184+
185+
186+
_interpolation_mode_to_cvcuda_interp: dict[InterpolationMode | str | int, "cvcuda.Interp"] = {}
187+
188+
189+
def _populate_interpolation_mode_to_cvcuda_interp():
190+
cvcuda = _import_cvcuda()
191+
192+
global _interpolation_mode_to_cvcuda_interp
193+
194+
_interpolation_mode_to_cvcuda_interp = {
195+
InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR,
196+
"bilinear": cvcuda.Interp.LINEAR,
197+
"linear": cvcuda.Interp.LINEAR,
198+
2: cvcuda.Interp.LINEAR,
199+
InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC,
200+
"bicubic": cvcuda.Interp.CUBIC,
201+
3: cvcuda.Interp.CUBIC,
202+
InterpolationMode.NEAREST: cvcuda.Interp.NEAREST,
203+
"nearest": cvcuda.Interp.NEAREST,
204+
0: cvcuda.Interp.NEAREST,
205+
InterpolationMode.BOX: cvcuda.Interp.BOX,
206+
"box": cvcuda.Interp.BOX,
207+
4: cvcuda.Interp.BOX,
208+
InterpolationMode.HAMMING: cvcuda.Interp.HAMMING,
209+
"hamming": cvcuda.Interp.HAMMING,
210+
5: cvcuda.Interp.HAMMING,
211+
InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS,
212+
"lanczos": cvcuda.Interp.LANCZOS,
213+
1: cvcuda.Interp.LANCZOS,
214+
}
215+
216+
217+
def _get_cvcuda_interp(interpolation: InterpolationMode | str | int) -> "cvcuda.Interp":
218+
if len(_interpolation_mode_to_cvcuda_interp) == 0:
219+
_populate_interpolation_mode_to_cvcuda_interp()
220+
221+
interp = _interpolation_mode_to_cvcuda_interp.get(interpolation)
222+
if interp is None:
223+
raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA")
224+
225+
return interp

0 commit comments

Comments
 (0)