Skip to content

Commit 1e864d8

Browse files
committed
initial cvcuda normalize kernel implementation
1 parent 4939355 commit 1e864d8

File tree

3 files changed

+109
-0
lines changed

3 files changed

+109
-0
lines changed

test/test_transforms_v2.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5633,6 +5633,79 @@ def test_correctness_image(self, mean, std, dtype, fn):
56335633
assert_equal(actual, expected)
56345634

56355635

5636+
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@needs_cuda
class TestNormalizeCVCUDA:
    """Exercise ``F.normalize`` / ``transforms.Normalize`` on CV-CUDA images."""

    # Every (mean, std) pair to try, keyed by color space. The RGB pairs are
    # shared with the plain-tensor normalize tests.
    MEANS_STDS = {
        "RGB": TestNormalize.MEANS_STDS,
        "GRAY": [([0.5], [2.0])],
    }
    # The first (mean, std) pair of each color space, for tests that only need one.
    MEAN_STD = {space: pairs[0] for space, pairs in MEANS_STDS.items()}

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
    def test_functional(self, color_space, batch_dims, dtype):
        """Run the generic functional check for every (mean, std) pair of the color space."""
        for mean, std in self.MEANS_STDS[color_space]:
            image = make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims)
            check_functional(F.normalize, image, mean=mean, std=std)

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
    def test_functional_scalar(self, color_space, batch_dims, dtype):
        """Run the generic functional check with scalar ``mean`` and ``std``."""
        image = make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims)
        check_functional(F.normalize, image, mean=0.5, std=2.0)

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32])
    @pytest.mark.parametrize("batch_dims", [(1,)])
    def test_functional_error(self, dtype, batch_dims):
        """Check the ``ValueError`` raised for ``inplace=True`` and for mismatched statistics."""
        rgb_mean, rgb_std = self.MEAN_STD["RGB"]
        gray_mean, gray_std = self.MEAN_STD["GRAY"]

        with pytest.raises(ValueError, match="Inplace normalization is not supported for CVCUDA."):
            F.normalize(make_image_cvcuda(batch_dims=batch_dims, dtype=dtype), mean=rgb_mean, std=rgb_std, inplace=True)

        # (image color space, mean, std, expected message), checked in this order.
        mismatches = [
            ("RGB", gray_mean, rgb_std, "Mean should have 3 elements. Got 1."),
            ("RGB", rgb_mean, gray_std, "Std should have 3 elements. Got 1."),
            ("GRAY", rgb_mean, gray_std, "Mean should have 1 elements. Got 3."),
            ("GRAY", gray_mean, rgb_std, "Std should have 1 elements. Got 3."),
        ]
        for color_space, mean, std, message in mismatches:
            with pytest.raises(ValueError, match=message):
                F.normalize(
                    make_image_cvcuda(batch_dims=batch_dims, color_space=color_space, dtype=dtype),
                    mean=mean,
                    std=std,
                )

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
    def test_transform(self, dtype, color_space, batch_dims):
        """Run the generic transform check for every (mean, std) pair of the color space."""
        for mean, std in self.MEANS_STDS[color_space]:
            transform = transforms.Normalize(mean=mean, std=std)
            image = make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims)
            check_transform(transform, image)

    @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)])
    def test_correctness_image(self, batch_dims):
        """Compare the CV-CUDA result against normalizing the same float32 CUDA tensor."""
        mean, std = self.MEAN_STD["RGB"]
        reference_input = make_image(batch_dims=batch_dims, dtype=torch.float32, device="cuda")
        cvcuda_input = F.to_cvcuda_tensor(reference_input)

        expected = F.normalize(reference_input, mean=mean, std=std)
        actual = F.cvcuda_to_tensor(F.normalize(cvcuda_input, mean=mean, std=std))

        assert_close(actual, expected, rtol=1e-7, atol=1e-7)
5707+
5708+
56365709
class TestClampBoundingBoxes:
56375710
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
56385711
@pytest.mark.parametrize("clamping_mode", ("soft", "hard", None))

torchvision/transforms/v2/functional/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@
153153
gaussian_noise_image,
154154
gaussian_noise_video,
155155
normalize,
156+
normalize_cvcuda,
156157
normalize_image,
157158
normalize_video,
158159
sanitize_bounding_boxes,

torchvision/transforms/v2/functional/_misc.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,41 @@ def normalize_video(video: torch.Tensor, mean: list[float], std: list[float], in
7979
return normalize_image(video, mean, std, inplace=inplace)
8080

8181

82+
def normalize_cvcuda(
83+
image: "cvcuda.Tensor",
84+
mean: Sequence[float | int] | float | int,
85+
std: Sequence[float | int] | float | int,
86+
inplace: bool = False,
87+
) -> "cvcuda.Tensor":
88+
if inplace:
89+
raise ValueError("Inplace normalization is not supported for CVCUDA.")
90+
91+
channels = image.shape[3]
92+
if isinstance(mean, float | int):
93+
mean = [mean] * channels
94+
elif len(mean) != channels:
95+
raise ValueError(f"Mean should have {channels} elements. Got {len(mean)}.")
96+
if isinstance(std, float | int):
97+
std = [std] * channels
98+
elif len(std) != channels:
99+
raise ValueError(f"Std should have {channels} elements. Got {len(std)}.")
100+
101+
mean = torch.as_tensor(mean, dtype=torch.float32)
102+
std = torch.as_tensor(std, dtype=torch.float32)
103+
mean_tensor = mean.reshape(1, 1, 1, channels)
104+
std_tensor = std.reshape(1, 1, 1, channels)
105+
mean_tensor = mean_tensor.cuda()
106+
std_tensor = std_tensor.cuda()
107+
mean_cv = cvcuda.as_tensor(mean_tensor, cvcuda.TensorLayout.NHWC)
108+
std_cv = cvcuda.as_tensor(std_tensor, cvcuda.TensorLayout.NHWC)
109+
110+
return cvcuda.normalize(image, base=mean_cv, scale=std_cv, flags=cvcuda.NormalizeFlags.SCALE_IS_STDDEV)
111+
112+
113+
# Register ``normalize_cvcuda`` for ``cvcuda.Tensor`` inputs of ``normalize``.
# This is guarded by CVCUDA_AVAILABLE because ``cvcuda.Tensor`` is evaluated
# as soon as the statement runs.
if CVCUDA_AVAILABLE:
    _normalize_cvcuda = _register_kernel_internal(
        normalize,
        cvcuda.Tensor,
    )(normalize_cvcuda)
115+
116+
82117
def gaussian_blur(inpt: torch.Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None) -> torch.Tensor:
83118
"""See :class:`~torchvision.transforms.v2.GaussianBlur` for details."""
84119
if torch.jit.is_scripting():

0 commit comments

Comments
 (0)