diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 3ce603c3ed2..34bd1633856 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -53,6 +53,7 @@
 from torchvision.transforms.v2._utils import check_type, is_pure_tensor
 from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes
 from torchvision.transforms.v2.functional._utils import (
+    _cvcuda_shared_stream,
     _get_kernel,
     _import_cvcuda,
     _is_cvcuda_available,
@@ -8075,3 +8076,29 @@ def test_different_sizes(self, make_input1, make_input2, query):
     def test_no_valid_input(self, query):
         with pytest.raises(TypeError, match="No image"):
             query(["blah"])
+
+
+@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+@needs_cuda
+class TestCVCUDASharedStream:
+    if _is_cvcuda_available():
+        cvcuda = _import_cvcuda()
+
+    def test_shared_stream(self):
+        stream = torch.cuda.Stream(device=None)
+
+        @_cvcuda_shared_stream
+        def _assert_cvcuda_shared_stream():
+            assert self.cvcuda.Stream.current.handle == stream.cuda_stream
+
+        with stream:
+            _assert_cvcuda_shared_stream()
+
+    def test_shared_stream_negative(self):
+        stream = torch.cuda.Stream(device=None)
+
+        def _assert_cvcuda_shared_stream_negative():
+            assert self.cvcuda.Stream.current.handle != stream.cuda_stream
+
+        with stream:
+            _assert_cvcuda_shared_stream_negative()
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index 0e27218bc89..49755e2d7cd 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -27,6 +27,7 @@
 from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format
 
 from ._utils import (
+    _cvcuda_shared_stream,
     _FillTypeJIT,
     _get_kernel,
     _import_cvcuda,
@@ -78,7 +79,9 @@ def _horizontal_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
 
 
 if CVCUDA_AVAILABLE:
-    _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_image_cvcuda)
+    _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(
+        _cvcuda_shared_stream(_horizontal_flip_image_cvcuda)
+    )
 
 
 @_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
@@ -174,7 +177,9 @@ def _vertical_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
 
 
 if CVCUDA_AVAILABLE:
-    _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_image_cvcuda)
+    _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(
+        _cvcuda_shared_stream(_vertical_flip_image_cvcuda)
+    )
 
 
 @_register_kernel_internal(vertical_flip, tv_tensors.Mask)
diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py
index 11480b30ef9..2eb3c0741a4 100644
--- a/torchvision/transforms/v2/functional/_utils.py
+++ b/torchvision/transforms/v2/functional/_utils.py
@@ -1,10 +1,13 @@
 import functools
 from collections.abc import Sequence
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, ParamSpec, TypeVar, Union
 
 import torch
 
 from torchvision import tv_tensors
 
+P = ParamSpec("P")
+R = TypeVar("R")
+
 _FillType = Union[int, float, Sequence[int], Sequence[float], None]
 _FillTypeJIT = Optional[list[float]]
@@ -177,3 +180,18 @@ def _is_cvcuda_tensor(inpt: Any) -> bool:
         return isinstance(inpt, cvcuda.Tensor)
     except ImportError:
         return False
+
+
+def _cvcuda_shared_stream(fn: Callable[P, R]) -> Callable[P, R]:
+    cvcuda = _import_cvcuda()
+
+    @functools.wraps(fn)
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+        stream = torch.cuda.current_stream()
+
+        with cvcuda.as_stream(stream):
+            result = fn(*args, **kwargs)
+
+        return result
+
+    return wrapper
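
For context, the following is a minimal usage sketch of the new _cvcuda_shared_stream decorator; it is not part of the patch and simply mirrors the positive test above. The decorator wraps a function so that, for the duration of the call, CV-CUDA's current stream is the CUDA stream PyTorch currently has active, which keeps CV-CUDA kernels ordered with surrounding PyTorch work on that stream. The sketch assumes cvcuda is installed and a CUDA device is available; _check_shared_stream is a hypothetical name used only for illustration.

import torch

from torchvision.transforms.v2.functional._utils import _cvcuda_shared_stream, _import_cvcuda

cvcuda = _import_cvcuda()


@_cvcuda_shared_stream
def _check_shared_stream() -> None:
    # Inside the wrapper, CV-CUDA's current stream wraps the same CUDA
    # stream handle as PyTorch's current stream.
    assert cvcuda.Stream.current.handle == torch.cuda.current_stream().cuda_stream


stream = torch.cuda.Stream()
with stream:  # the PyTorch stream made current here is picked up by the wrapper
    _check_shared_stream()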