Skip to content

Commit e14e210

Browse files
committed
merge with main
2 parents e51dc7e + 6b56de1 commit e14e210

File tree

8 files changed, with 159 additions (+159) and 57 deletions (-57).

8 files changed, with 159 additions (+159) and 57 deletions (-57).

.github/scripts/setup-env.sh

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,12 @@ echo '::group::Install TorchVision'
8282
pip install -e . -v --no-build-isolation
8383
echo '::endgroup::'
8484

85+
if [[ "${CVCUDA:-}" == "1" ]]; then
86+
echo '::group::Install CV-CUDA'
87+
pip install --progress-bar=off cvcuda-cu12
88+
echo '::endgroup::'
89+
fi
90+
8591
echo '::group::Install torchvision-extra-decoders'
8692
# This can be done after torchvision was built
8793
if [[ "$(uname)" == "Linux" && "$(uname -m)" != "aarch64" ]]; then

.github/workflows/tests.yml

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,36 @@ jobs:
4545
4646
./.github/scripts/unittest.sh
4747
48+
unittests-linux-cvcuda:
49+
strategy:
50+
matrix:
51+
python-version:
52+
- "3.10"
53+
runner: ["linux.g5.4xlarge.nvidia.gpu"]
54+
gpu-arch-type: ["cuda"]
55+
gpu-arch-version: ["12.6"]
56+
fail-fast: false
57+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
58+
permissions:
59+
id-token: write
60+
contents: read
61+
with:
62+
repository: pytorch/vision
63+
runner: ${{ matrix.runner }}
64+
gpu-arch-type: ${{ matrix.gpu-arch-type }}
65+
gpu-arch-version: ${{ matrix.gpu-arch-version }}
66+
timeout: 120
67+
test-infra-ref: main
68+
script: |
69+
set -euo pipefail
70+
71+
export PYTHON_VERSION=${{ matrix.python-version }}
72+
export GPU_ARCH_TYPE=${{ matrix.gpu-arch-type }}
73+
export GPU_ARCH_VERSION=${{ matrix.gpu-arch-version }}
74+
export CVCUDA="1"
75+
76+
./.github/scripts/unittest.sh
77+
4878
unittests-macos:
4979
strategy:
5080
matrix:

test/common_utils.py

Lines changed: 24 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -21,14 +21,13 @@
2121
from torchvision import io, tv_tensors
2222
from torchvision.transforms._functional_tensor import _max_value as get_max_value
2323
from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
24-
from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available
24+
from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor
2525
from torchvision.utils import _Image_fromarray
2626

2727

2828
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
2929
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
3030
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
31-
CVCUDA_AVAILABLE = _is_cvcuda_available()
3231
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
3332
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
3433
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@@ -277,17 +276,6 @@ def combinations_grid(**kwargs):
277276
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
278277

279278

280-
def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor:
281-
tensor = cvcuda_to_tensor(tensor)
282-
if tensor.ndim != 4:
283-
raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.")
284-
if tensor.shape[0] != 1:
285-
raise ValueError(
286-
f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}."
287-
)
288-
return tensor.squeeze(0).cpu()
289-
290-
291279
class ImagePair(TensorLikePair):
292280
def __init__(
293281
self,
@@ -297,13 +285,24 @@ def __init__(
297285
mae=False,
298286
**other_parameters,
299287
):
300-
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
301-
actual, expected = (to_image(input) for input in [actual, expected])
302-
303-
# handle check for CV-CUDA Tensors
304-
if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor):
305-
# Use the PIL compatible tensor, so we can always compare with PIL.Image.Image
306-
actual = cvcuda_to_pil_compatible_tensor(actual)
288+
# Convert PIL images to tv_tensors.Image (regardless of what the other is)
289+
if isinstance(actual, PIL.Image.Image):
290+
actual = to_image(actual)
291+
if isinstance(expected, PIL.Image.Image):
292+
expected = to_image(expected)
293+
294+
if _is_cvcuda_available():
295+
if _is_cvcuda_tensor(actual):
296+
actual = cvcuda_to_tensor(actual)
297+
# Remove batch dimension if it's 1 for easier comparison against 3D PIL images
298+
if actual.shape[0] == 1:
299+
actual = actual[0]
300+
actual = actual.cpu()
301+
if _is_cvcuda_tensor(expected):
302+
expected = cvcuda_to_tensor(expected)
303+
if expected.shape[0] == 1:
304+
expected = expected[0]
305+
expected = expected.cpu()
307306

308307
super().__init__(actual, expected, **other_parameters)
309308
self.mae = mae
@@ -559,5 +558,9 @@ def ignore_jit_no_profile_information_warning():
559558
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
560559
# them.
561560
with warnings.catch_warnings():
562-
warnings.filterwarnings("ignore", message=re.escape("operator() profile_node %"), category=UserWarning)
561+
warnings.filterwarnings(
562+
"ignore",
563+
message=re.escape("operator() profile_node %"),
564+
category=UserWarning,
565+
)
563566
yield

test/test_transforms_v2.py

Lines changed: 71 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1241,6 +1241,10 @@ def test_kernel_video(self):
12411241
make_image_tensor,
12421242
make_image_pil,
12431243
make_image,
1244+
pytest.param(
1245+
make_image_cvcuda,
1246+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1247+
),
12441248
make_bounding_boxes,
12451249
make_segmentation_mask,
12461250
make_video,
@@ -1256,13 +1260,20 @@ def test_functional(self, make_input):
12561260
(F.horizontal_flip_image, torch.Tensor),
12571261
(F._geometry._horizontal_flip_image_pil, PIL.Image.Image),
12581262
(F.horizontal_flip_image, tv_tensors.Image),
1263+
pytest.param(
1264+
F._geometry._horizontal_flip_image_cvcuda,
1265+
None,
1266+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1267+
),
12591268
(F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes),
12601269
(F.horizontal_flip_mask, tv_tensors.Mask),
12611270
(F.horizontal_flip_video, tv_tensors.Video),
12621271
(F.horizontal_flip_keypoints, tv_tensors.KeyPoints),
12631272
],
12641273
)
12651274
def test_functional_signature(self, kernel, input_type):
1275+
if kernel is F._geometry._horizontal_flip_image_cvcuda:
1276+
input_type = _import_cvcuda().Tensor
12661277
check_functional_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
12671278

12681279
@pytest.mark.parametrize(
@@ -1271,6 +1282,10 @@ def test_functional_signature(self, kernel, input_type):
12711282
make_image_tensor,
12721283
make_image_pil,
12731284
make_image,
1285+
pytest.param(
1286+
make_image_cvcuda,
1287+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1288+
),
12741289
make_bounding_boxes,
12751290
make_segmentation_mask,
12761291
make_video,
@@ -1284,13 +1299,23 @@ def test_transform(self, make_input, device):
12841299
@pytest.mark.parametrize(
12851300
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
12861301
)
1287-
def test_image_correctness(self, fn):
1288-
image = make_image(dtype=torch.uint8, device="cpu")
1289-
1302+
@pytest.mark.parametrize(
1303+
"make_input",
1304+
[
1305+
make_image,
1306+
pytest.param(
1307+
make_image_cvcuda,
1308+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1309+
),
1310+
],
1311+
)
1312+
def test_image_correctness(self, fn, make_input):
1313+
image = make_input()
12901314
actual = fn(image)
1291-
expected = F.to_image(F.horizontal_flip(F.to_pil_image(image)))
1292-
1293-
torch.testing.assert_close(actual, expected)
1315+
if make_input is make_image_cvcuda:
1316+
image = F.cvcuda_to_tensor(image)[0].cpu()
1317+
expected = F.horizontal_flip(F.to_pil_image(image))
1318+
assert_equal(actual, expected)
12941319

12951320
def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
12961321
affine_matrix = np.array(
@@ -1346,6 +1371,10 @@ def test_keypoints_correctness(self, fn):
13461371
make_image_tensor,
13471372
make_image_pil,
13481373
make_image,
1374+
pytest.param(
1375+
make_image_cvcuda,
1376+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1377+
),
13491378
make_bounding_boxes,
13501379
make_segmentation_mask,
13511380
make_video,
@@ -1355,11 +1384,8 @@ def test_keypoints_correctness(self, fn):
13551384
@pytest.mark.parametrize("device", cpu_and_cuda())
13561385
def test_transform_noop(self, make_input, device):
13571386
input = make_input(device=device)
1358-
13591387
transform = transforms.RandomHorizontalFlip(p=0)
1360-
13611388
output = transform(input)
1362-
13631389
assert_equal(output, input)
13641390

13651391

@@ -1857,6 +1883,10 @@ def test_kernel_video(self):
18571883
make_image_tensor,
18581884
make_image_pil,
18591885
make_image,
1886+
pytest.param(
1887+
make_image_cvcuda,
1888+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1889+
),
18601890
make_bounding_boxes,
18611891
make_segmentation_mask,
18621892
make_video,
@@ -1872,13 +1902,20 @@ def test_functional(self, make_input):
18721902
(F.vertical_flip_image, torch.Tensor),
18731903
(F._geometry._vertical_flip_image_pil, PIL.Image.Image),
18741904
(F.vertical_flip_image, tv_tensors.Image),
1905+
pytest.param(
1906+
F._geometry._vertical_flip_image_cvcuda,
1907+
None,
1908+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1909+
),
18751910
(F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes),
18761911
(F.vertical_flip_mask, tv_tensors.Mask),
18771912
(F.vertical_flip_video, tv_tensors.Video),
18781913
(F.vertical_flip_keypoints, tv_tensors.KeyPoints),
18791914
],
18801915
)
18811916
def test_functional_signature(self, kernel, input_type):
1917+
if kernel is F._geometry._vertical_flip_image_cvcuda:
1918+
input_type = _import_cvcuda().Tensor
18821919
check_functional_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type)
18831920

18841921
@pytest.mark.parametrize(
@@ -1887,6 +1924,10 @@ def test_functional_signature(self, kernel, input_type):
18871924
make_image_tensor,
18881925
make_image_pil,
18891926
make_image,
1927+
pytest.param(
1928+
make_image_cvcuda,
1929+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1930+
),
18901931
make_bounding_boxes,
18911932
make_segmentation_mask,
18921933
make_video,
@@ -1898,13 +1939,23 @@ def test_transform(self, make_input, device):
18981939
check_transform(transforms.RandomVerticalFlip(p=1), make_input(device=device))
18991940

19001941
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
1901-
def test_image_correctness(self, fn):
1902-
image = make_image(dtype=torch.uint8, device="cpu")
1903-
1942+
@pytest.mark.parametrize(
1943+
"make_input",
1944+
[
1945+
make_image,
1946+
pytest.param(
1947+
make_image_cvcuda,
1948+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1949+
),
1950+
],
1951+
)
1952+
def test_image_correctness(self, fn, make_input):
1953+
image = make_input()
19041954
actual = fn(image)
1905-
expected = F.to_image(F.vertical_flip(F.to_pil_image(image)))
1906-
1907-
torch.testing.assert_close(actual, expected)
1955+
if make_input is make_image_cvcuda:
1956+
image = F.cvcuda_to_tensor(image)[0].cpu()
1957+
expected = F.vertical_flip(F.to_pil_image(image))
1958+
assert_equal(actual, expected)
19081959

19091960
def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
19101961
affine_matrix = np.array(
@@ -1956,6 +2007,10 @@ def test_keypoints_correctness(self, fn):
19562007
make_image_tensor,
19572008
make_image_pil,
19582009
make_image,
2010+
pytest.param(
2011+
make_image_cvcuda,
2012+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
2013+
),
19592014
make_bounding_boxes,
19602015
make_segmentation_mask,
19612016
make_video,
@@ -1965,11 +2020,8 @@ def test_keypoints_correctness(self, fn):
19652020
@pytest.mark.parametrize("device", cpu_and_cuda())
19662021
def test_transform_noop(self, make_input, device):
19672022
input = make_input(device=device)
1968-
19692023
transform = transforms.RandomVerticalFlip(p=0)
1970-
19712024
output = transform(input)
1972-
19732025
assert_equal(output, input)
19742026

19752027

@@ -6826,7 +6878,7 @@ def test_functional_and_transform(self, dtype, device, color_space, batch_dims,
68266878
assert F.get_size(output) == F.get_size(input_tensor)
68276879

68286880
def test_functional_error(self):
6829-
with pytest.raises(TypeError, match="cvcuda_img should be `cvcuda.Tensor`"):
6881+
with pytest.raises(TypeError, match=r"cvcuda_img should be ``cvcuda\.Tensor``\. Got .+\."):
68306882
F.cvcuda_to_tensor(object())
68316883

68326884

torchvision/transforms/v2/_geometry.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from torchvision.ops.boxes import box_iou
1212
from torchvision.transforms.functional import _get_perspective_coeffs
1313
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
14-
from torchvision.transforms.v2.functional._utils import _FillType
14+
from torchvision.transforms.v2.functional._utils import _FillType, _is_cvcuda_available, _is_cvcuda_tensor
1515

1616
from ._transform import _RandomApplyTransform
1717
from ._utils import (
@@ -30,6 +30,8 @@
3030
query_size,
3131
)
3232

33+
CVCUDA_AVAILABLE = _is_cvcuda_available()
34+
3335

3436
class RandomHorizontalFlip(_RandomApplyTransform):
3537
"""Horizontally flip the input with a given probability.
@@ -45,6 +47,9 @@ class RandomHorizontalFlip(_RandomApplyTransform):
4547

4648
_v1_transform_cls = _transforms.RandomHorizontalFlip
4749

50+
if CVCUDA_AVAILABLE:
51+
_transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,)
52+
4853
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
4954
return self._call_kernel(F.horizontal_flip, inpt)
5055

@@ -63,6 +68,9 @@ class RandomVerticalFlip(_RandomApplyTransform):
6368

6469
_v1_transform_cls = _transforms.RandomVerticalFlip
6570

71+
if CVCUDA_AVAILABLE:
72+
_transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,)
73+
6674
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
6775
return self._call_kernel(F.vertical_flip, inpt)
6876

0 commit comments

Comments
 (0)