Skip to content

Commit 7d9c0f4

Browse files
committed
begin editing for main changes
1 parent a3d8797 commit 7d9c0f4

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

test/test_transforms_v2.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
assert_equal,
2626
cache,
2727
cpu_and_cuda,
28-
cvcuda_to_pil_compatible_tensor,
2928
freeze_rng_state,
3029
ignore_jit_no_profile_information_warning,
3130
make_bounding_boxes,
@@ -5848,14 +5847,14 @@ def test_functional(self, make_input):
58485847
(F.invert_image, tv_tensors.Image),
58495848
(F.invert_video, tv_tensors.Video),
58505849
pytest.param(
5851-
F._color._invert_cvcuda,
5852-
"cvcuda.Tensor",
5850+
F._color._invert_image_cvcuda,
5851+
None,
58535852
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
58545853
),
58555854
],
58565855
)
58575856
def test_functional_signature(self, kernel, input_type):
5858-
if input_type == "cvcuda.Tensor":
5857+
if kernel is F._color._invert_image_cvcuda:
58595858
input_type = _import_cvcuda().Tensor
58605859
check_functional_kernel_signature_match(F.invert, kernel=kernel, input_type=input_type)
58615860

@@ -5890,7 +5889,7 @@ def test_correctness_image(self, make_input, fn):
58905889
actual = fn(image)
58915890

58925891
if make_input is make_image_cvcuda:
5893-
image = cvcuda_to_pil_compatible_tensor(image)
5892+
image = F.cvcuda_to_tensor(image)[0].cpu()
58945893

58955894
expected = F.to_image(F.invert(F.to_pil_image(image)))
58965895

torchvision/transforms/v2/_color.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@
55
import torch
66
from torchvision import transforms as _transforms
77
from torchvision.transforms.v2 import functional as F, Transform
8+
from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor
89

910
from ._transform import _RandomApplyTransform
1011
from ._utils import query_chw
1112

1213

14+
CVCUDA_AVAILABLE = _is_cvcuda_available()
15+
16+
1317
class Grayscale(Transform):
1418
"""Convert images or videos to grayscale.
1519
@@ -282,6 +286,9 @@ class RandomInvert(_RandomApplyTransform):
282286

283287
_v1_transform_cls = _transforms.RandomInvert
284288

289+
if CVCUDA_AVAILABLE:
290+
_transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,)
291+
285292
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
286293
return self._call_kernel(F.invert, inpt)
287294

torchvision/transforms/v2/functional/_color.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -690,11 +690,10 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
690690
return invert_image(video)
691691

692692

693-
if CVCUDA_AVAILABLE:
694-
_invert_cvcuda_tensors = {}
693+
_invert_cvcuda_tensors = {}
695694

696695

697-
def _invert_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
696+
def _invert_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
698697
cvcuda = _import_cvcuda()
699698

700699
# save the tensors into a dictionary only if CV-CUDA is actually used
@@ -725,7 +724,7 @@ def _invert_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
725724

726725

727726
if CVCUDA_AVAILABLE:
728-
_register_kernel_internal(invert, _import_cvcuda().Tensor)(_invert_cvcuda)
727+
_register_kernel_internal(invert, _import_cvcuda().Tensor)(_invert_image_cvcuda)
729728

730729

731730
def permute_channels(inpt: torch.Tensor, permutation: list[int]) -> torch.Tensor:

0 commit comments

Comments
 (0)