Skip to content

Commit 7ccc301

Browse files
committed
implement invert cvcuda
1 parent fbea584 commit 7ccc301

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

test/test_transforms_v2.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5825,7 +5825,18 @@ def test_kernel_image(self, dtype, device):
58255825
def test_kernel_video(self):
    """Pass ``F.invert_video`` and a freshly made video to ``check_kernel``."""
    video = make_video()
    check_kernel(F.invert_video, video)
58275827

5828-
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
5828+
@pytest.mark.parametrize(
    "make_input",
    [
        make_image_tensor,
        make_image,
        make_image_pil,
        make_video,
        # The CV-CUDA input factory is only exercised when CV-CUDA is available.
        pytest.param(
            make_image_cvcuda,
            marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
        ),
    ],
)
def test_functional(self, make_input):
    """Pass ``F.invert`` and one sample from ``make_input`` to ``check_functional``."""
    sample = make_input()
    check_functional(F.invert, sample)
58315842

@@ -5836,12 +5847,30 @@ def test_functional(self, make_input):
58365847
(F._color._invert_image_pil, PIL.Image.Image),
58375848
(F.invert_image, tv_tensors.Image),
58385849
(F.invert_video, tv_tensors.Video),
5850+
pytest.param(
5851+
F._color._invert_cvcuda,
5852+
"cvcuda.Tensor",
5853+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
5854+
),
58395855
],
58405856
)
58415857
def test_functional_signature(self, kernel, input_type):
5858+
if input_type == "cvcuda.Tensor":
5859+
input_type = _import_cvcuda().Tensor
58425860
check_functional_kernel_signature_match(F.invert, kernel=kernel, input_type=input_type)
58435861

5844-
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
5862+
@pytest.mark.parametrize(
    "make_input",
    [
        make_image_tensor,
        make_image_pil,
        make_image,
        make_video,
        # The CV-CUDA input factory is only exercised when CV-CUDA is available.
        pytest.param(
            make_image_cvcuda,
            marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
        ),
    ],
)
def test_transform(self, make_input):
    """Pass ``RandomInvert(p=1)`` and one sample from ``make_input`` to ``check_transform``."""
    always_invert = transforms.RandomInvert(p=1)
    sample = make_input()
    check_transform(always_invert, sample)
58475876

@@ -5854,6 +5883,16 @@ def test_correctness_image(self, fn):
58545883

58555884
assert_equal(actual, expected)
58565885

5886+
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
@pytest.mark.parametrize("fn", [F.invert, transform_cls_to_functional(transforms.RandomInvert, p=1)])
def test_correctness_cvcuda(self, dtype, fn):
    """Compare the CV-CUDA inversion result against ``F.invert_image`` on the same data."""
    torch_image = make_image(batch_dims=(1,), dtype=dtype, device="cuda")

    # Reference result computed by the plain tensor kernel.
    expected = F.invert_image(torch_image)

    # Round-trip through CV-CUDA: convert, invert with ``fn``, convert back.
    inverted_cv = fn(F.to_cvcuda_tensor(torch_image))
    actual = F.cvcuda_to_tensor(inverted_cv)

    assert_equal(actual, expected)
5895+
58575896

58585897
class TestPosterize:
58595898
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])

torchvision/transforms/v2/functional/_color.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,42 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
690690
return invert_image(video)
691691

692692

693+
if _CVCUDA_AVAILABLE:
    # Module-level cache of the constant CV-CUDA tensors used by ``_invert_cvcuda``.
    # It starts empty and is filled on first use, so nothing is allocated at import time.
    _invert_cvcuda_tensors = dict()
695+
696+
697+
def _invert_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
    """Invert the colors of a CV-CUDA image tensor.

    The inversion is expressed through ``cvcuda.normalize``, which computes
    ``output = (input - base) * scale * globalscale + globalshift``. With
    ``base = 0``, ``scale = -1`` and ``globalscale = 1`` this reduces to
    ``globalshift - input``, where ``globalshift`` is 255 for uint8 inputs and
    1.0 for float32 inputs.

    The constant ``base`` and ``scale`` tensors are created once per channel
    count and cached in ``_invert_cvcuda_tensors``, so inputs are not limited
    to three channels.

    Args:
        image: CV-CUDA tensor of dtype uint8 or float32. The channel count is
            read from the last dimension.
            NOTE(review): this assumes a channels-last layout (``NHWC``/``HWC``)
            -- confirm against callers.

    Returns:
        The value returned by ``cvcuda.normalize`` for the inverted image.

    Raises:
        ValueError: If the dtype of ``image`` is neither uint8 nor float32.
    """
    cvcuda = _import_cvcuda()

    # Resolve the dtype-dependent maximum first so unsupported inputs fail
    # before any constant tensor is allocated on the GPU.
    if image.dtype == cvcuda.Type.U8:
        shift = 255.0
    elif image.dtype == cvcuda.Type.F32:
        shift = 1.0
    else:
        raise ValueError(f"Input image dtype must be uint8 or float32, got {image.dtype}")

    num_channels = image.shape[-1]

    def _cached_constant(name: str, value: float) -> "cvcuda.Tensor":
        # One float32 constant of shape (1, 1, 1, C) per (name, channel count).
        key = (name, num_channels)
        if key not in _invert_cvcuda_tensors:
            _invert_cvcuda_tensors[key] = cvcuda.as_tensor(
                torch.full((1, 1, 1, num_channels), value, dtype=torch.float32, device="cuda"),
                "NHWC",
            )
        return _invert_cvcuda_tensors[key]

    base = _cached_constant("base", 0.0)
    scale = _cached_constant("scale", -1.0)

    # output = (input - 0) * (-1) * 1 + shift = shift - input
    return cvcuda.normalize(image, base=base, scale=scale, globalscale=1.0, globalshift=shift)
723+
724+
725+
if _CVCUDA_AVAILABLE:
    # Register ``_invert_cvcuda`` as the ``invert`` kernel for ``cvcuda.Tensor`` inputs.
    # This is guarded because the tensor type is obtained through ``_import_cvcuda()``.
    _register_kernel_internal(
        invert,
        _import_cvcuda().Tensor,
    )(_invert_cvcuda)
727+
728+
693729
def permute_channels(inpt: torch.Tensor, permutation: list[int]) -> torch.Tensor:
694730
"""Permute the channels of the input according to the given permutation.
695731

0 commit comments

Comments
 (0)