Skip to content

Commit 184e379

Browse files
committed
simplify normalize testing into a single test, parameterized on input creation
1 parent 57ca083 commit 184e379

File tree

2 files changed

+27
-15
lines changed

2 files changed

+27
-15
lines changed

test/test_transforms_v2.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5650,24 +5650,36 @@ def _reference_normalize_image(self, image, *, mean, std):
56505650

56515651
@pytest.mark.parametrize(("mean", "std"), MEANS_STDS)
56525652
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64])
5653+
@pytest.mark.parametrize(
5654+
"make_input",
5655+
[
5656+
make_image,
5657+
pytest.param(
5658+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5659+
),
5660+
],
5661+
)
56535662
@pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)])
5654-
def test_correctness_image(self, mean, std, dtype, fn):
5655-
image = make_image(dtype=dtype)
5663+
def test_correctness_image(self, mean, std, dtype, make_input, fn):
5664+
if make_input == make_image_cvcuda and dtype != torch.float32:
5665+
pytest.skip("CVCUDA only supports float32 for normalize")
5666+
5667+
image = make_input(dtype=dtype)
56565668

56575669
actual = fn(image, mean=mean, std=std)
5658-
expected = self._reference_normalize_image(image, mean=mean, std=std)
56595670

5660-
assert_equal(actual, expected)
5671+
if make_input == make_image_cvcuda:
5672+
image = F.cvcuda_to_tensor(image).to(device="cpu")
5673+
image = image.squeeze(0)
5674+
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
5675+
actual = actual.squeeze(0)
56615676

5662-
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5663-
@pytest.mark.parametrize(("mean", "std"), MEANS_STDS)
5664-
@pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)])
5665-
def test_correctness_cvcuda(self, mean, std, fn):
5666-
image = make_image(batch_dims=(1,), dtype=torch.float32, device="cuda")
5667-
cvc_image = F.to_cvcuda_tensor(image)
5668-
actual = F._misc._normalize_cvcuda(cvc_image, mean=mean, std=std)
5669-
expected = fn(image, mean=mean, std=std)
5670-
torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=1e-7, atol=1e-7)
5677+
expected = self._reference_normalize_image(image, mean=mean, std=std)
5678+
5679+
if make_input == make_image_cvcuda:
5680+
torch.testing.assert_close(actual, expected, rtol=0, atol=1e-6)
5681+
else:
5682+
assert_equal(actual, expected)
56715683

56725684

56735685
class TestClampBoundingBoxes:

torchvision/transforms/v2/_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torch import nn
99
from torch.utils._pytree import tree_flatten, tree_unflatten
1010
from torchvision import tv_tensors
11-
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
11+
from torchvision.transforms.v2._utils import check_type, has_any, is_cvcuda_tensor, is_pure_tensor
1212
from torchvision.utils import _log_api_usage_once
1313

1414
from .functional._utils import _get_kernel
@@ -23,7 +23,7 @@ class Transform(nn.Module):
2323

2424
# Class attribute defining transformed types. Other types are passed-through without any transformation
2525
# We support both Types and callables that are able to do further checks on the type of the input.
26-
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)
26+
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)
2727

2828
def __init__(self) -> None:
2929
super().__init__()

0 commit comments

Comments
 (0)