Skip to content

Commit d1a4eff

Browse files
committed
update invert with new PR comments revisions
1 parent 8908ca5 commit d1a4eff

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

test/test_transforms_v2.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
assert_equal,
2626
cache,
2727
cpu_and_cuda,
28+
cvcuda_to_pil_compatible_tensor,
2829
freeze_rng_state,
2930
ignore_jit_no_profile_information_warning,
3031
make_bounding_boxes,
@@ -5821,23 +5822,26 @@ def test_functional_signature(self, kernel, input_type):
58215822
def test_transform(self, make_input):
58225823
check_transform(transforms.RandomInvert(p=1), make_input())
58235824

5825+
@pytest.mark.parametrize(
5826+
"make_input",
5827+
[
5828+
make_image,
5829+
pytest.param(
5830+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5831+
),
5832+
],
5833+
)
58245834
@pytest.mark.parametrize("fn", [F.invert, transform_cls_to_functional(transforms.RandomInvert, p=1)])
5825-
def test_correctness_image(self, fn):
5826-
image = make_image(dtype=torch.uint8, device="cpu")
5835+
def test_correctness_image(self, make_input, fn):
5836+
image = make_input(dtype=torch.uint8, device="cpu")
58275837

58285838
actual = fn(image)
5829-
expected = F.to_image(F.invert(F.to_pil_image(image)))
58305839

5831-
assert_equal(actual, expected)
5840+
if make_input is make_image_cvcuda:
5841+
image = cvcuda_to_pil_compatible_tensor(image)
5842+
5843+
expected = F.to_image(F.invert(F.to_pil_image(image)))
58325844

5833-
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5834-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
5835-
@pytest.mark.parametrize("fn", [F.invert, transform_cls_to_functional(transforms.RandomInvert, p=1)])
5836-
def test_correctness_cvcuda(self, dtype, fn):
5837-
image = make_image(batch_dims=(1,), dtype=dtype, device="cuda")
5838-
cv_image = F.to_cvcuda_tensor(image)
5839-
actual = F.cvcuda_to_tensor(fn(cv_image))
5840-
expected = F.invert_image(image)
58415845
assert_equal(actual, expected)
58425846

58435847

torchvision/transforms/v2/functional/_color.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -690,13 +690,15 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
690690
return invert_image(video)
691691

692692

693-
if _CVCUDA_AVAILABLE:
693+
if CVCUDA_AVAILABLE:
694694
_invert_cvcuda_tensors = {}
695695

696696

697697
def _invert_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
698698
cvcuda = _import_cvcuda()
699699

700+
# save the tensors into a dictionary only if CV-CUDA is actually used
701+
# we save these here, since they are static and small in size
700702
if "base" not in _invert_cvcuda_tensors:
701703
_invert_cvcuda_tensors["base"] = cvcuda.as_tensor(
702704
torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device="cuda").reshape(1, 1, 1, 3).contiguous(), "NHWC"
@@ -722,7 +724,7 @@ def _invert_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
722724
return cvcuda.normalize(image, base=base, scale=scale, globalscale=1.0, globalshift=shift)
723725

724726

725-
if _CVCUDA_AVAILABLE:
727+
if CVCUDA_AVAILABLE:
726728
_register_kernel_internal(invert, _import_cvcuda().Tensor)(_invert_cvcuda)
727729

728730

0 commit comments

Comments
 (0)