Skip to content

Commit 79ea0da

Browse files
committed
fix: normalize_cvcuda move to correct patterns for tests/exporting
1 parent 01efae7 commit 79ea0da

File tree

3 files changed

+48
-79
lines changed

3 files changed

+48
-79
lines changed

test/test_transforms_v2.py

Lines changed: 37 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -5570,7 +5570,17 @@ def test_kernel_image_inplace(self, device):
55705570
def test_kernel_video(self):
55715571
check_kernel(F.normalize_video, make_video(dtype=torch.float32), mean=self.MEAN, std=self.STD)
55725572

5573-
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
5573+
@pytest.mark.parametrize(
5574+
"make_input",
5575+
[
5576+
make_image_tensor,
5577+
make_image,
5578+
make_video,
5579+
pytest.param(
5580+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5581+
),
5582+
],
5583+
)
55745584
def test_functional(self, make_input):
55755585
check_functional(F.normalize, make_input(dtype=torch.float32), mean=self.MEAN, std=self.STD)
55765586

@@ -5580,6 +5590,11 @@ def test_functional(self, make_input):
55805590
(F.normalize_image, torch.Tensor),
55815591
(F.normalize_image, tv_tensors.Image),
55825592
(F.normalize_video, tv_tensors.Video),
5593+
pytest.param(
5594+
F._misc._normalize_cvcuda,
5595+
_import_cvcuda().Tensor,
5596+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
5597+
),
55835598
],
55845599
)
55855600
def test_functional_signature(self, kernel, input_type):
@@ -5608,7 +5623,17 @@ def _sample_input_adapter(self, transform, input, device):
56085623
adapted_input[key] = value
56095624
return adapted_input
56105625

5611-
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
5626+
@pytest.mark.parametrize(
5627+
"make_input",
5628+
[
5629+
make_image_tensor,
5630+
make_image,
5631+
make_video,
5632+
pytest.param(
5633+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5634+
),
5635+
],
5636+
)
56125637
def test_transform(self, make_input):
56135638
check_transform(
56145639
transforms.Normalize(mean=self.MEAN, std=self.STD),
@@ -5632,78 +5657,16 @@ def test_correctness_image(self, mean, std, dtype, fn):
56325657

56335658
assert_equal(actual, expected)
56345659

5635-
5636-
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5637-
@needs_cuda
5638-
class TestNormalizeCVCUDA:
5639-
MEANS_STDS = {
5640-
"RGB": TestNormalize.MEANS_STDS,
5641-
"GRAY": [([0.5], [2.0])],
5642-
}
5643-
MEAN_STD = {
5644-
"RGB": MEANS_STDS["RGB"][0],
5645-
"GRAY": MEANS_STDS["GRAY"][0],
5646-
}
5647-
5648-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32])
5649-
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
5650-
@pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
5651-
def test_functional(self, color_space, batch_dims, dtype):
5652-
means_stds = self.MEANS_STDS[color_space]
5653-
for mean, std in means_stds:
5654-
image = make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims)
5655-
check_functional(F.normalize, image, mean=mean, std=std)
5656-
5657-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32])
5658-
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
5659-
@pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
5660-
def test_functional_scalar(self, color_space, batch_dims, dtype):
5661-
image = make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims)
5662-
check_functional(F.normalize, image, mean=0.5, std=2.0)
5663-
5664-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32])
5665-
@pytest.mark.parametrize("batch_dims", [(1,)])
5666-
def test_functional_error(self, dtype, batch_dims):
5667-
rgb_mean, rgb_std = self.MEAN_STD["RGB"]
5668-
gray_mean, gray_std = self.MEAN_STD["GRAY"]
5669-
5670-
with pytest.raises(ValueError, match="Inplace normalization is not supported for CVCUDA."):
5671-
F.normalize(make_image_cvcuda(batch_dims=batch_dims, dtype=dtype), mean=rgb_mean, std=rgb_std, inplace=True)
5672-
5673-
with pytest.raises(ValueError, match="Mean should have 3 elements. Got 1."):
5674-
F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="RGB", dtype=dtype), mean=gray_mean, std=rgb_std)
5675-
5676-
with pytest.raises(ValueError, match="Std should have 3 elements. Got 1."):
5677-
F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="RGB", dtype=dtype), mean=rgb_mean, std=gray_std)
5678-
5679-
with pytest.raises(ValueError, match="Mean should have 1 elements. Got 3."):
5680-
F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="GRAY", dtype=dtype), mean=rgb_mean, std=gray_std)
5681-
5682-
with pytest.raises(ValueError, match="Std should have 1 elements. Got 3."):
5683-
F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="GRAY", dtype=dtype), mean=gray_mean, std=rgb_std)
5684-
5685-
@pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32])
5686-
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
5687-
@pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
5688-
def test_transform(self, dtype, color_space, batch_dims):
5689-
means_stds = self.MEANS_STDS[color_space]
5690-
for mean, std in means_stds:
5691-
check_transform(
5692-
transforms.Normalize(mean=mean, std=std),
5693-
make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims),
5694-
)
5695-
5696-
@pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
5697-
def test_correctness_image(self, batch_dims):
5698-
mean, std = self.MEAN_STD["RGB"]
5699-
torch_image = make_image(batch_dims=batch_dims, dtype=torch.float32, device="cuda")
5700-
cvc_image = F.to_cvcuda_tensor(torch_image)
5701-
5702-
gold = F.normalize(torch_image, mean=mean, std=std)
5703-
image = F.normalize(cvc_image, mean=mean, std=std)
5704-
image = F.cvcuda_to_tensor(image)
5705-
5706-
assert_close(image, gold, rtol=1e-7, atol=1e-7)
5660+
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5661+
@pytest.mark.parametrize(("mean", "std"), MEANS_STDS)
5662+
@pytest.mark.parametrize("dtype", [torch.float32])
5663+
@pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)])
5664+
def test_correctness_cvcuda(self, mean, std, dtype, fn):
5665+
image = make_image(batch_dims=(1,), dtype=dtype, device="cuda")
5666+
cvc_image = F.to_cvcuda_tensor(image)
5667+
actual = F._misc._normalize_cvcuda(cvc_image, mean=mean, std=std)
5668+
expected = fn(image, mean=mean, std=std)
5669+
torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=1e-7, atol=1e-7)
57075670

57085671

57095672
class TestClampBoundingBoxes:

torchvision/transforms/v2/functional/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@
153153
gaussian_noise_image,
154154
gaussian_noise_video,
155155
normalize,
156-
normalize_cvcuda,
157156
normalize_image,
158157
normalize_video,
159158
sanitize_bounding_boxes,

torchvision/transforms/v2/functional/_misc.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,22 @@ def normalize_video(video: torch.Tensor, mean: list[float], std: list[float], in
7979
return normalize_image(video, mean, std, inplace=inplace)
8080

8181

82-
def normalize_cvcuda(
82+
def _normalize_cvcuda(
8383
image: "cvcuda.Tensor",
84-
mean: Sequence[float | int] | float | int,
85-
std: Sequence[float | int] | float | int,
84+
mean: list[float],
85+
std: list[float],
8686
inplace: bool = False,
8787
) -> "cvcuda.Tensor":
88+
cvcuda = _import_cvcuda()
8889
if inplace:
8990
raise ValueError("Inplace normalization is not supported for CVCUDA.")
9091

92+
# CV-CUDA supports signed int and float tensors
93+
# torchvision only supports uint and float, right now CV-CUDA doesnt expose float16, so only check 32
94+
# in the future add float16 once exposed in CV-CUDA
95+
if not (image.dtype == cvcuda.Type.F32):
96+
raise ValueError(f"Input tensor should be a float tensor. Got {image.dtype}.")
97+
9198
channels = image.shape[3]
9299
if isinstance(mean, float | int):
93100
mean = [mean] * channels
@@ -115,7 +122,7 @@ def normalize_cvcuda(
115122

116123

117124
if CVCUDA_AVAILABLE:
118-
_normalize_cvcuda = _register_kernel_internal(normalize, cvcuda.Tensor)(normalize_cvcuda)
125+
_normalize_cvcuda_registered = _register_kernel_internal(normalize, _import_cvcuda().Tensor)(_normalize_cvcuda)
119126

120127

121128
def gaussian_blur(inpt: torch.Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None) -> torch.Tensor:

0 commit comments

Comments
 (0)