Commit 31e08e4

begin work on finalizing the crop PR to include five and ten crop, adhere to new PR reviews for flip
1 parent ee626ae commit 31e08e4

5 files changed: 179 additions & 21 deletions

test/common_utils.py

Lines changed: 29 additions & 4 deletions
@@ -20,7 +20,13 @@
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import io, tv_tensors
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
+from torchvision.transforms.v2.functional import (
+    cvcuda_to_tensor,
+    is_cvcuda_tensor,
+    to_cvcuda_tensor,
+    to_image,
+    to_pil_image,
+)
 from torchvision.utils import _Image_fromarray

@@ -275,6 +281,17 @@ def combinations_grid(**kwargs):
     return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]


+def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor:
+    tensor = cvcuda_to_tensor(tensor)
+    if tensor.ndim != 4:
+        raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.")
+    if tensor.shape[0] != 1:
+        raise ValueError(
+            f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}."
+        )
+    return tensor.squeeze(0).cpu()
+
+
 class ImagePair(TensorLikePair):
     def __init__(
         self,
@@ -284,8 +301,17 @@ def __init__(
         mae=False,
         **other_parameters,
     ):
-        if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
-            actual, expected = (to_image(input) for input in [actual, expected])
+        # Convert PIL images to tv_tensors.Image (regardless of what the other is)
+        if isinstance(actual, PIL.Image.Image):
+            actual = to_image(actual)
+        if isinstance(expected, PIL.Image.Image):
+            expected = to_image(expected)
+
+        # attempt to convert CV-CUDA tensors to torch tensors
+        if is_cvcuda_tensor(actual):
+            actual = cvcuda_to_pil_compatible_tensor(actual)
+        if is_cvcuda_tensor(expected):
+            expected = cvcuda_to_pil_compatible_tensor(expected)

         super().__init__(actual, expected, **other_parameters)
         self.mae = mae
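
With this change, ImagePair (and therefore assert_equal) accepts mixed input kinds: PIL images and CV-CUDA tensors are each normalized before TensorLikePair compares them. A minimal sketch of a comparison this enables (inputs illustrative, using the helpers from this file):

    actual = make_image_cvcuda((32, 32))  # cvcuda.Tensor with batch dimension 1
    expected = to_pil_image(cvcuda_to_pil_compatible_tensor(actual))
    assert_equal(actual, expected)  # both sides are converted inside ImagePair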
@@ -401,7 +427,6 @@ def make_image_pil(*args, **kwargs):


 def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
-    # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4)
     return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))

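For reference, a usage sketch of the factory together with the new helper (a sketch assuming cvcuda is installed and a CUDA device is available; the size is illustrative):

    cv_image = make_image_cvcuda((32, 32))  # to_cvcuda_tensor requires a batch dimension, hence batch_dims=(1,)
    torch_image = cvcuda_to_pil_compatible_tensor(cv_image)  # 3D torch.Tensor on CPU, ready for to_pil_image
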
test/test_transforms_v2.py

Lines changed: 65 additions & 14 deletions
@@ -24,6 +24,7 @@
     assert_equal,
     cache,
     cpu_and_cuda,
+    cvcuda_to_pil_compatible_tensor,
     freeze_rng_state,
     ignore_jit_no_profile_information_warning,
     make_bounding_boxes,
@@ -3624,10 +3625,7 @@ def test_transform_image_correctness(self, param, value, seed, make_input):
         torch.manual_seed(seed)

         if make_input == make_image_cvcuda:
-            actual = F.cvcuda_to_tensor(actual).to(device="cpu")
-            actual = actual.squeeze(0)
-            image = F.cvcuda_to_tensor(image).to(device="cpu")
-            image = image.squeeze(0)
+            image = cvcuda_to_pil_compatible_tensor(image)

         expected = F.to_image(transform(F.to_pil_image(image)))

@@ -4995,10 +4993,7 @@ def test_image_correctness(self, output_size, make_input, fn):
         actual = fn(image, output_size)

         if make_input == make_image_cvcuda:
-            actual = F.cvcuda_to_tensor(actual).to(device="cpu")
-            actual = actual.squeeze(0)
-            image = F.cvcuda_to_tensor(image).to(device="cpu")
-            image = image.squeeze(0)
+            image = cvcuda_to_pil_compatible_tensor(image)

         expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))

@@ -6274,7 +6269,15 @@ def wrapper(*args, **kwargs):

    @pytest.mark.parametrize(
        "make_input",
-        [make_image_tensor, make_image_pil, make_image, make_video],
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            make_video,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
    )
    @pytest.mark.parametrize("functional", [F.five_crop, F.ten_crop])
    def test_functional(self, make_input, functional):
@@ -6292,13 +6295,27 @@ def test_functional(self, make_input, functional):
            (F.five_crop, F._geometry._five_crop_image_pil, PIL.Image.Image),
            (F.five_crop, F.five_crop_image, tv_tensors.Image),
            (F.five_crop, F.five_crop_video, tv_tensors.Video),
+            pytest.param(
+                F.five_crop,
+                F._geometry._five_crop_cvcuda,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
            (F.ten_crop, F.ten_crop_image, torch.Tensor),
            (F.ten_crop, F._geometry._ten_crop_image_pil, PIL.Image.Image),
            (F.ten_crop, F.ten_crop_image, tv_tensors.Image),
            (F.ten_crop, F.ten_crop_video, tv_tensors.Video),
+            pytest.param(
+                F.ten_crop,
+                F._geometry._ten_crop_cvcuda,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
        ],
    )
    def test_functional_signature(self, functional, kernel, input_type):
+        if input_type == "cvcuda.Tensor":
+            input_type = _import_cvcuda().Tensor
        check_functional_kernel_signature_match(functional, kernel=kernel, input_type=input_type)

 class _TransformWrapper(nn.Module):
@@ -6320,7 +6337,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

    @pytest.mark.parametrize(
        "make_input",
-        [make_image_tensor, make_image_pil, make_image, make_video],
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            make_video,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
    )
    @pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])
    def test_transform(self, make_input, transform_cls):
@@ -6338,29 +6363,55 @@ def test_transform_error(self, make_input, transform_cls):
        with pytest.raises(TypeError, match="not supported"):
            transform(make_input(self.INPUT_SIZE))

+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
    @pytest.mark.parametrize("fn", [F.five_crop, transform_cls_to_functional(transforms.FiveCrop)])
-    def test_correctness_image_five_crop(self, fn):
-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
+    def test_correctness_image_five_crop(self, make_input, fn):
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")

        actual = fn(image, size=self.OUTPUT_SIZE)
+
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
+
        expected = F.five_crop(F.to_pil_image(image), size=self.OUTPUT_SIZE)

        assert isinstance(actual, tuple)
        assert_equal(actual, [F.to_image(e) for e in expected])

+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
    @pytest.mark.parametrize("fn_or_class", [F.ten_crop, transforms.TenCrop])
    @pytest.mark.parametrize("vertical_flip", [False, True])
-    def test_correctness_image_ten_crop(self, fn_or_class, vertical_flip):
+    def test_correctness_image_ten_crop(self, make_input, fn_or_class, vertical_flip):
        if fn_or_class is transforms.TenCrop:
            fn = transform_cls_to_functional(fn_or_class, size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)
            kwargs = dict()
        else:
            fn = fn_or_class
            kwargs = dict(size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)

-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")

        actual = fn(image, **kwargs)
+
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
+
        expected = F.ten_crop(F.to_pil_image(image), size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)

        assert isinstance(actual, tuple)
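
Assuming a CV-CUDA-enabled environment, the new cases can be exercised selectively with a keyword filter such as `pytest test/test_transforms_v2.py -k "cvcuda and crop"` (invocation illustrative); without cvcuda installed, the skipif marks report these cases as skipped rather than failing collection.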

torchvision/transforms/v2/_geometry.py

Lines changed: 11 additions & 0 deletions
@@ -26,6 +26,7 @@
     get_bounding_boxes,
     has_all,
     has_any,
+    is_cvcuda_tensor,
     is_pure_tensor,
     query_size,
 )
@@ -186,6 +187,8 @@ class CenterCrop(Transform):

     _v1_transform_cls = _transforms.CenterCrop

+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(self, size: Union[int, Sequence[int]]):
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
@@ -352,6 +355,8 @@ class FiveCrop(Transform):

     _v1_transform_cls = _transforms.FiveCrop

+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(self, size: Union[int, Sequence[int]]) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
@@ -396,6 +401,8 @@ class TenCrop(Transform):

     _v1_transform_cls = _transforms.TenCrop

+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
@@ -803,6 +810,8 @@ class RandomCrop(Transform):

     _v1_transform_cls = _transforms.RandomCrop

+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def _extract_params_for_v1_transform(self) -> dict[str, Any]:
         params = super()._extract_params_for_v1_transform()

@@ -1113,6 +1122,8 @@ class RandomIoUCrop(Transform):
             Default, 40.
     """

+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(
         self,
         min_scale: float = 0.3,
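
Overriding _transformed_types per class makes CV-CUDA support opt-in: only the transforms above dispatch on CV-CUDA tensors, while any Transform without the override passes them through unchanged. A minimal sketch of the resulting behavior (assumes cvcuda and a CUDA device; cv_image as produced by the test helpers):

    from torchvision.transforms import v2 as transforms

    out = transforms.CenterCrop(16)(cv_image)  # handled: is_cvcuda_tensor is in CenterCrop._transformed_types
    out = transforms.ColorJitter()(cv_image)   # passed through: the base class no longer matches CV-CUDA tensors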

torchvision/transforms/v2/_transform.py

Lines changed: 3 additions & 3 deletions
@@ -8,10 +8,10 @@
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import tv_tensors
-from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
+from torchvision.transforms.v2._utils import check_type, has_any, is_cvcuda_tensor, is_pure_tensor
 from torchvision.utils import _log_api_usage_once

-from .functional._utils import _get_kernel, is_cvcuda_tensor
+from .functional._utils import _get_kernel


 class Transform(nn.Module):
@@ -23,7 +23,7 @@ class Transform(nn.Module):

     # Class attribute defining transformed types. Other types are passed-through without any transformation
     # We support both Types and callables that are able to do further checks on the type of the input.
-    _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)
+    _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)

     def __init__(self) -> None:
         super().__init__()
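
The callable entries work because check_type treats non-type members of _transformed_types as predicates. Roughly, simplified from torchvision.transforms.v2._utils:

    def check_type(obj, types_or_checks):
        for type_or_check in types_or_checks:
            if isinstance(type_or_check, type):
                if isinstance(obj, type_or_check):
                    return True
            elif type_or_check(obj):
                return True
        return False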

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 71 additions & 0 deletions
@@ -142,6 +142,14 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
     return horizontal_flip_image(video)


+def _horizontal_flip_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
+    return _import_cvcuda().flip(image, flipCode=1)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_cvcuda)
+
+
 def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
     """See :class:`~torchvision.transforms.v2.RandomVerticalFlip` for details."""
     if torch.jit.is_scripting():
@@ -230,6 +238,14 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
     return vertical_flip_image(video)


+def _vertical_flip_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
+    return _import_cvcuda().flip(image, flipCode=0)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_cvcuda)
+
+
 # We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
 # prevalent and well understood. Thus, we just alias them without deprecating the old names.
 hflip = horizontal_flip
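
Once registered, the public functionals route CV-CUDA tensors to these kernels automatically; flipCode follows the OpenCV convention (positive flips horizontally, 0 flips vertically). A usage sketch (assumes cvcuda is installed and a CUDA device is available):

    import torch
    from torchvision.transforms.v2 import functional as F

    cv_image = F.to_cvcuda_tensor(torch.rand(1, 3, 32, 32, device="cuda"))
    flipped = F.horizontal_flip(cv_image)  # dispatched to _horizontal_flip_cvcuda via the kernel registry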
@@ -3003,6 +3019,29 @@ def five_crop_video(
     return five_crop_image(video, size)


+def _five_crop_cvcuda(
+    image: "cvcuda.Tensor",
+    size: list[int],
+) -> tuple["cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor"]:
+    crop_height, crop_width = _parse_five_crop_size(size)
+    image_height, image_width = image.shape[-2:]
+
+    if crop_width > image_width or crop_height > image_height:
+        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
+
+    # crop boxes follow the (top, left, height, width) convention of crop_image
+    tl = _crop_cvcuda(image, 0, 0, crop_height, crop_width)
+    tr = _crop_cvcuda(image, 0, image_width - crop_width, crop_height, crop_width)
+    bl = _crop_cvcuda(image, image_height - crop_height, 0, crop_height, crop_width)
+    br = _crop_cvcuda(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+    center = _center_crop_cvcuda(image, [crop_height, crop_width])
+
+    return tl, tr, bl, br, center
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(five_crop, _import_cvcuda().Tensor)(_five_crop_cvcuda)
+
+
 def ten_crop(
     inpt: torch.Tensor, size: list[int], vertical_flip: bool = False
 ) -> tuple[
@@ -3098,3 +3137,35 @@ def ten_crop_video(
     torch.Tensor,
 ]:
     return ten_crop_image(video, size, vertical_flip=vertical_flip)
+
+
+def _ten_crop_cvcuda(
+    image: "cvcuda.Tensor",
+    size: list[int],
+    vertical_flip: bool = False,
+) -> tuple[
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+]:
+    non_flipped = _five_crop_cvcuda(image, size)
+
+    if vertical_flip:
+        image = _vertical_flip_cvcuda(image)
+    else:
+        image = _horizontal_flip_cvcuda(image)
+
+    flipped = _five_crop_cvcuda(image, size)
+
+    return non_flipped + flipped
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(ten_crop, _import_cvcuda().Tensor)(_ten_crop_cvcuda)
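
With both kernels registered, F.five_crop and F.ten_crop accept CV-CUDA tensors directly and mirror the torch kernels' ordering: five crops of the original image followed by five crops of the flipped image. A sketch (size illustrative, cv_image as above):

    tl, tr, bl, br, center = F.five_crop(cv_image, size=[16, 16])
    crops = F.ten_crop(cv_image, size=[16, 16], vertical_flip=True)  # tuple of 10 cvcuda.Tensors; last five from the vertically flipped image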
