Skip to content

Commit 8908ca5

Browse files
committed
implement invert cvcuda
1 parent 98d7dfb commit 8908ca5

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

test/test_transforms_v2.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5772,7 +5772,18 @@ def test_kernel_image(self, dtype, device):
57725772
def test_kernel_video(self):
    """Run the generic kernel checks against ``F.invert_video`` on a freshly made video."""
    video = make_video()
    check_kernel(F.invert_video, video)
57745774

5775-
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
5775+
@pytest.mark.parametrize(
    "make_input",
    [
        make_image_tensor,
        make_image,
        make_image_pil,
        make_video,
        # The CV-CUDA input is only exercised when CV-CUDA is available.
        pytest.param(
            make_image_cvcuda,
            marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
        ),
    ],
)
def test_functional(self, make_input):
    """Run the generic functional checks for ``F.invert`` on each parametrized input kind."""
    inpt = make_input()
    check_functional(F.invert, inpt)
57785789

@@ -5783,12 +5794,30 @@ def test_functional(self, make_input):
57835794
(F._color._invert_image_pil, PIL.Image.Image),
57845795
(F.invert_image, tv_tensors.Image),
57855796
(F.invert_video, tv_tensors.Video),
5797+
pytest.param(
5798+
F._color._invert_cvcuda,
5799+
"cvcuda.Tensor",
5800+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
5801+
),
57865802
],
57875803
)
57885804
def test_functional_signature(self, kernel, input_type):
5805+
if input_type == "cvcuda.Tensor":
5806+
input_type = _import_cvcuda().Tensor
57895807
check_functional_kernel_signature_match(F.invert, kernel=kernel, input_type=input_type)
57905808

5791-
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
5809+
@pytest.mark.parametrize(
    "make_input",
    [
        make_image_tensor,
        make_image_pil,
        make_image,
        make_video,
        # The CV-CUDA input is only exercised when CV-CUDA is available.
        pytest.param(
            make_image_cvcuda,
            marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
        ),
    ],
)
def test_transform(self, make_input):
    """Run the generic transform checks for ``RandomInvert(p=1)`` on each parametrized input kind."""
    transform = transforms.RandomInvert(p=1)
    check_transform(transform, make_input())
57945823

@@ -5801,6 +5830,16 @@ def test_correctness_image(self, fn):
58015830

58025831
assert_equal(actual, expected)
58035832

5833+
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
@pytest.mark.parametrize("fn", [F.invert, transform_cls_to_functional(transforms.RandomInvert, p=1)])
def test_correctness_cvcuda(self, dtype, fn):
    """Check that inverting a CV-CUDA tensor matches ``F.invert_image`` on the same torch image.

    The image is created on ``"cuda"`` with a single batch dimension, converted to a
    CV-CUDA tensor, passed through ``fn``, and converted back before comparison.
    """
    torch_image = make_image(batch_dims=(1,), dtype=dtype, device="cuda")

    # CV-CUDA path: convert, invert, convert back.
    cvcuda_input = F.to_cvcuda_tensor(torch_image)
    cvcuda_output = fn(cvcuda_input)
    actual = F.cvcuda_to_tensor(cvcuda_output)

    # Reference path: the torch image kernel.
    expected = F.invert_image(torch_image)

    assert_equal(actual, expected)
5842+
58045843

58055844
class TestPosterize:
58065845
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])

torchvision/transforms/v2/functional/_color.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,42 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
690690
return invert_image(video)
691691

692692

693+
if _CVCUDA_AVAILABLE:
    # Lazily filled cache for the constant ``base`` / ``scale`` operands used by
    # ``_invert_cvcuda``; entries are created on the first call and then reused.
    _invert_cvcuda_tensors = {}


def _invert_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
    """Invert a CV-CUDA image tensor via ``cvcuda.normalize``.

    Per the formula noted below, the result is ``shift - input``, where ``shift`` is
    ``255.0`` for ``cvcuda.Type.U8`` inputs and ``1.0`` for ``cvcuda.Type.F32`` inputs.

    Args:
        image: The CV-CUDA tensor to invert.

    Returns:
        The tensor returned by ``cvcuda.normalize``.

    Raises:
        ValueError: If ``image.dtype`` is neither ``cvcuda.Type.U8`` nor ``cvcuda.Type.F32``.
    """
    cvcuda = _import_cvcuda()

    # Build each constant operand at most once, in the order "base" then "scale".
    # NOTE(review): the constants are created on the default "cuda" device with 3 channels;
    # confirm behavior for inputs on another GPU or with a different channel count.
    for key, fill in (("base", 0.0), ("scale", -1.0)):
        if key not in _invert_cvcuda_tensors:
            constant = torch.tensor([fill, fill, fill], dtype=torch.float32, device="cuda")
            _invert_cvcuda_tensors[key] = cvcuda.as_tensor(constant.reshape(1, 1, 1, 3).contiguous(), "NHWC")

    # The additive term is the maximum value of the input's value range.
    dtype = image.dtype
    if dtype == cvcuda.Type.U8:
        shift = 255.0
    elif dtype == cvcuda.Type.F32:
        shift = 1.0
    else:
        raise ValueError(f"Input image dtype must be uint8 or float32, got {image.dtype}")

    # normalize computes: output = (input - base) * scale * global_scale + shift
    # With base = 0, scale = -1 and global_scale = 1 this reduces to: shift - input
    return cvcuda.normalize(
        image,
        base=_invert_cvcuda_tensors["base"],
        scale=_invert_cvcuda_tensors["scale"],
        globalscale=1.0,
        globalshift=shift,
    )


if _CVCUDA_AVAILABLE:
    # Register the CV-CUDA kernel so ``invert`` dispatches to it for ``cvcuda.Tensor`` inputs.
    _register_kernel_internal(invert, _import_cvcuda().Tensor)(_invert_cvcuda)
727+
728+
693729
def permute_channels(inpt: torch.Tensor, permutation: list[int]) -> torch.Tensor:
694730
"""Permute the channels of the input according to the given permutation.
695731

0 commit comments

Comments
 (0)