Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes
from torchvision.transforms.v2.functional._utils import (
_cvcuda_shared_stream,
_get_kernel,
_import_cvcuda,
_is_cvcuda_available,
Expand Down Expand Up @@ -8075,3 +8076,29 @@ def test_different_sizes(self, make_input1, make_input2, query):
def test_no_valid_input(self, query):
    """``query`` must reject an input that holds no image with a ``TypeError``."""
    # A flat list containing only a string: nothing in it is an image.
    image_free_input = ["blah"]

    with pytest.raises(TypeError, match="No image"):
        query(image_free_input)


@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@needs_cuda
class TestCVCUDASharedStream:
    """Checks how ``_cvcuda_shared_stream`` affects CV-CUDA's current stream.

    The positive test asserts that a decorated function sees the PyTorch
    stream that is active at call time; the negative test asserts that an
    undecorated function does not.
    """

    # ``cvcuda`` is only bound as a class attribute when the import helper
    # reports CV-CUDA as available; the tests read it through ``self.cvcuda``.
    if _is_cvcuda_available():
        cvcuda = _import_cvcuda()

    def test_shared_stream(self):
        """A decorated function runs with CV-CUDA on the entered PyTorch stream."""
        torch_stream = torch.cuda.Stream(device=None)

        @_cvcuda_shared_stream
        def check_streams_match():
            current_handle = self.cvcuda.Stream.current.handle
            assert current_handle == torch_stream.cuda_stream

        # Entering the stream makes it PyTorch's current stream for the call.
        with torch_stream:
            check_streams_match()

    def test_shared_stream_negative(self):
        """Without the decorator, CV-CUDA's current stream differs from the PyTorch one."""
        torch_stream = torch.cuda.Stream(device=None)

        # Deliberately NOT wrapped in ``_cvcuda_shared_stream``.
        def check_streams_differ():
            current_handle = self.cvcuda.Stream.current.handle
            assert current_handle != torch_stream.cuda_stream

        with torch_stream:
            check_streams_differ()
9 changes: 7 additions & 2 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format

from ._utils import (
_cvcuda_shared_stream,
_FillTypeJIT,
_get_kernel,
_import_cvcuda,
Expand Down Expand Up @@ -78,7 +79,9 @@ def _horizontal_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":


if CVCUDA_AVAILABLE:
_register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_image_cvcuda)
_register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(
_cvcuda_shared_stream(_horizontal_flip_image_cvcuda)
)


@_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
Expand Down Expand Up @@ -174,7 +177,9 @@ def _vertical_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":


if CVCUDA_AVAILABLE:
_register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_image_cvcuda)
_register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(
_cvcuda_shared_stream(_vertical_flip_image_cvcuda)
)


@_register_kernel_internal(vertical_flip, tv_tensors.Mask)
Expand Down
20 changes: 19 additions & 1 deletion torchvision/transforms/v2/functional/_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import functools
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, ParamSpec, TypeVar, Union

import torch
from torchvision import tv_tensors

P = ParamSpec("P")
R = TypeVar("R")

_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[list[float]]

Expand Down Expand Up @@ -177,3 +180,18 @@ def _is_cvcuda_tensor(inpt: Any) -> bool:
return isinstance(inpt, cvcuda.Tensor)
except ImportError:
return False


def _cvcuda_shared_stream(fn: Callable[P, R]) -> Callable[P, R]:
    """Decorate ``fn`` so it executes inside a CV-CUDA stream context built from PyTorch's current CUDA stream.

    CV-CUDA is imported once, when the decorator is applied. The PyTorch
    stream is looked up on every call of the returned wrapper, so the wrapper
    follows whichever stream is current at call time.

    Args:
        fn: The callable to wrap.

    Returns:
        A callable with the same signature as ``fn`` (metadata preserved via
        ``functools.wraps``) that returns whatever ``fn`` returns.
    """
    cvcuda_module = _import_cvcuda()

    @functools.wraps(fn)
    def shared_stream_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # Resolve the stream per call rather than at decoration time.
        with cvcuda_module.as_stream(torch.cuda.current_stream()):
            return fn(*args, **kwargs)

    return shared_stream_wrapper