Skip to content

Commit c035df1

Browse files
committed
add stanardized setup to main for easier updating of PRs and branches
1 parent e3dd700 commit c035df1

File tree

4 files changed

+43
-6
lines changed

4 files changed

+43
-6
lines changed

test/common_utils.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@
2020
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
2121
from torchvision import io, tv_tensors
2222
from torchvision.transforms._functional_tensor import _max_value as get_max_value
23-
from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
23+
from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
24+
from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available
2425
from torchvision.utils import _Image_fromarray
2526

2627

2728
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
2829
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
2930
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
31+
CVCUDA_AVAILABLE = _is_cvcuda_available()
3032
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
3133
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
3234
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@@ -275,6 +277,17 @@ def combinations_grid(**kwargs):
275277
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
276278

277279

280+
def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor:
281+
tensor = cvcuda_to_tensor(tensor)
282+
if tensor.ndim != 4:
283+
raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.")
284+
if tensor.shape[0] != 1:
285+
raise ValueError(
286+
f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}."
287+
)
288+
return tensor.squeeze(0).cpu()
289+
290+
278291
class ImagePair(TensorLikePair):
279292
def __init__(
280293
self,
@@ -287,6 +300,11 @@ def __init__(
287300
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
288301
actual, expected = (to_image(input) for input in [actual, expected])
289302

303+
# handle check for CV-CUDA Tensors
304+
if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor):
305+
# Use the PIL compatible tensor, so we can always compare with PIL.Image.Image
306+
actual = cvcuda_to_pil_compatible_tensor(actual)
307+
290308
super().__init__(actual, expected, **other_parameters)
291309
self.mae = mae
292310

@@ -401,7 +419,6 @@ def make_image_pil(*args, **kwargs):
401419

402420

403421
def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
404-
# explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4)
405422
return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))
406423

407424

test/test_transforms_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torchvision.transforms.v2 as transforms
2222

2323
from common_utils import (
24+
assert_close,
2425
assert_equal,
2526
cache,
2627
cpu_and_cuda,
@@ -41,7 +42,6 @@
4142
)
4243

4344
from torch import nn
44-
from torch.testing import assert_close
4545
from torch.utils._pytree import tree_flatten, tree_map
4646
from torch.utils.data import DataLoader, default_collate
4747
from torchvision import tv_tensors

torchvision/transforms/v2/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
182182
chws = {
183183
tuple(get_dimensions(inpt))
184184
for inpt in flat_inputs
185-
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
185+
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor))
186186
}
187187
if not chws:
188188
raise TypeError("No image or video was found in the sample")

torchvision/transforms/v2/functional/_meta.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]:
5151
return get_dimensions_image(video)
5252

5353

54+
def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]:
55+
# CV-CUDA tensor is always in NHWC layout
56+
# get_dimensions is CHW
57+
return [image.shape[3], image.shape[1], image.shape[2]]
58+
59+
60+
if CVCUDA_AVAILABLE:
61+
_register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda)
62+
63+
5464
def get_num_channels(inpt: torch.Tensor) -> int:
5565
if torch.jit.is_scripting():
5666
return get_num_channels_image(inpt)
@@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int:
8797
get_image_num_channels = get_num_channels
8898

8999

100+
def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int:
101+
# CV-CUDA tensor is always in NHWC layout
102+
# get_num_channels is C
103+
return image.shape[3]
104+
105+
106+
if CVCUDA_AVAILABLE:
107+
_register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda)
108+
109+
90110
def get_size(inpt: torch.Tensor) -> list[int]:
91111
if torch.jit.is_scripting():
92112
return get_size_image(inpt)
@@ -114,7 +134,7 @@ def _get_size_image_pil(image: PIL.Image.Image) -> list[int]:
114134
return [height, width]
115135

116136

117-
def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:
137+
def _get_size_cvcuda(image: "cvcuda.Tensor") -> list[int]:
118138
"""Get size of `cvcuda.Tensor` with NHWC layout."""
119139
hw = list(image.shape[-3:-1])
120140
ndims = len(hw)
@@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:
125145

126146

127147
if CVCUDA_AVAILABLE:
128-
_get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda)
148+
_register_kernel_internal(get_size, cvcuda.Tensor)(_get_size_cvcuda)
129149

130150

131151
@_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)

0 commit comments

Comments
 (0)