Skip to content

Commit 44db71c

Browse files
committed
implement additional cvcuda infra for all branches to avoid duplicate setup
1 parent 617079d commit 44db71c

File tree

8 files changed

+68
-10
lines changed

8 files changed

+68
-10
lines changed

torchvision/transforms/v2/_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
1212
from torchvision.utils import _log_api_usage_once
1313

14-
from .functional._utils import _get_kernel
14+
from .functional._utils import _get_kernel, is_cvcuda_tensor
1515

1616

1717
class Transform(nn.Module):
@@ -23,7 +23,7 @@ class Transform(nn.Module):
2323

2424
# Class attribute defining transformed types. Other types are passed-through without any transformation
2525
# We support both Types and callables that are able to do further checks on the type of the input.
26-
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)
26+
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)
2727

2828
def __init__(self) -> None:
2929
super().__init__()

torchvision/transforms/v2/_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from torchvision._utils import sequence_to_str
1616

1717
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
18-
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
18+
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor
1919
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
2020

2121

@@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
207207
tv_tensors.Mask,
208208
tv_tensors.BoundingBoxes,
209209
tv_tensors.KeyPoints,
210+
is_cvcuda_tensor,
210211
),
211212
)
212213
}

torchvision/transforms/v2/functional/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from torchvision.transforms import InterpolationMode # usort: skip
22

3-
from ._utils import is_pure_tensor, register_kernel # usort: skip
3+
from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip
44

55
from ._meta import (
66
clamp_bounding_boxes,

torchvision/transforms/v2/functional/_augment.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import io
2+
from typing import TYPE_CHECKING
23

34
import PIL.Image
45

@@ -8,7 +9,15 @@
89
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
910
from torchvision.utils import _log_api_usage_once
1011

11-
from ._utils import _get_kernel, _register_kernel_internal
12+
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal
13+
14+
15+
CVCUDA_AVAILABLE = _is_cvcuda_available()
16+
17+
if TYPE_CHECKING:
18+
import cvcuda # type: ignore[import-not-found]
19+
if CVCUDA_AVAILABLE:
20+
cvcuda = _import_cvcuda() # noqa: F811
1221

1322

1423
def erase(

torchvision/transforms/v2/functional/_color.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import TYPE_CHECKING
2+
13
import PIL.Image
24
import torch
35
from torch.nn.functional import conv2d
@@ -9,7 +11,15 @@
911

1012
from ._misc import _num_value_bits, to_dtype_image
1113
from ._type_conversion import pil_to_tensor, to_pil_image
12-
from ._utils import _get_kernel, _register_kernel_internal
14+
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal
15+
16+
17+
CVCUDA_AVAILABLE = _is_cvcuda_available()
18+
19+
if TYPE_CHECKING:
20+
import cvcuda # type: ignore[import-not-found]
21+
if CVCUDA_AVAILABLE:
22+
cvcuda = _import_cvcuda() # noqa: F811
1323

1424

1525
def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numbers
33
import warnings
44
from collections.abc import Sequence
5-
from typing import Any, Optional, Union
5+
from typing import Any, Optional, TYPE_CHECKING, Union
66

77
import PIL.Image
88
import torch
@@ -26,7 +26,22 @@
2626

2727
from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format
2828

29-
from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
29+
from ._utils import (
30+
_FillTypeJIT,
31+
_get_kernel,
32+
_import_cvcuda,
33+
_is_cvcuda_available,
34+
_register_five_ten_crop_kernel_internal,
35+
_register_kernel_internal,
36+
)
37+
38+
39+
CVCUDA_AVAILABLE = _is_cvcuda_available()
40+
41+
if TYPE_CHECKING:
42+
import cvcuda # type: ignore[import-not-found]
43+
if CVCUDA_AVAILABLE:
44+
cvcuda = _import_cvcuda() # noqa: F811
3045

3146

3247
def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:

torchvision/transforms/v2/functional/_misc.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Optional
2+
from typing import Optional, TYPE_CHECKING
33

44
import PIL.Image
55
import torch
@@ -13,7 +13,14 @@
1313

1414
from ._meta import _convert_bounding_box_format
1515

16-
from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor
16+
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor
17+
18+
CVCUDA_AVAILABLE = _is_cvcuda_available()
19+
20+
if TYPE_CHECKING:
21+
import cvcuda # type: ignore[import-not-found]
22+
if CVCUDA_AVAILABLE:
23+
cvcuda = _import_cvcuda() # noqa: F811
1724

1825

1926
def normalize(

torchvision/transforms/v2/functional/_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,19 @@ def _is_cvcuda_available():
169169
return True
170170
except ImportError:
171171
return False
172+
173+
174+
def is_cvcuda_tensor(inpt: Any) -> bool:
175+
"""
176+
Check if the input is a CVCUDA tensor.
177+
178+
Args:
179+
inpt: The input to check.
180+
181+
Returns:
182+
True if the input is a CV-CUDA tensor, False otherwise.
183+
"""
184+
if _is_cvcuda_available():
185+
cvcuda = _import_cvcuda()
186+
return isinstance(inpt, cvcuda.Tensor)
187+
return False

0 commit comments

Comments
 (0)