Skip to content

Commit d379658

Browse files
committed
complete and tested adjust_hue
1 parent b11c38a commit d379658

File tree

2 files changed

+62
-3
lines changed

2 files changed

+62
-3
lines changed

test/test_transforms_v2.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6201,7 +6201,18 @@ def test_kernel_image(self, dtype, device):
62016201
def test_kernel_video(self):
62026202
check_kernel(F.adjust_hue_video, make_video(), hue_factor=0.25)
62036203

6204-
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
6204+
@pytest.mark.parametrize(
6205+
"make_input",
6206+
[
6207+
make_image_tensor,
6208+
make_image,
6209+
make_image_pil,
6210+
make_video,
6211+
pytest.param(
6212+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
6213+
),
6214+
],
6215+
)
62056216
def test_functional(self, make_input):
62066217
check_functional(F.adjust_hue, make_input(), hue_factor=0.25)
62076218

@@ -6212,9 +6223,16 @@ def test_functional(self, make_input):
62126223
(F._color._adjust_hue_image_pil, PIL.Image.Image),
62136224
(F.adjust_hue_image, tv_tensors.Image),
62146225
(F.adjust_hue_video, tv_tensors.Video),
6226+
pytest.param(
6227+
F._color._adjust_hue_cvcuda,
6228+
"cvcuda.Tensor",
6229+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
6230+
),
62156231
],
62166232
)
62176233
def test_functional_signature(self, kernel, input_type):
6234+
if input_type == "cvcuda.Tensor":
6235+
input_type = _import_cvcuda().Tensor
62186236
check_functional_kernel_signature_match(F.adjust_hue, kernel=kernel, input_type=input_type)
62196237

62206238
def test_functional_error(self):
@@ -6225,11 +6243,27 @@ def test_functional_error(self):
62256243
with pytest.raises(ValueError, match=re.escape("is not in [-0.5, 0.5]")):
62266244
F.adjust_hue(make_image(), hue_factor=hue_factor)
62276245

6246+
@pytest.mark.parametrize(
6247+
"make_input",
6248+
[
6249+
make_image,
6250+
pytest.param(
6251+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
6252+
),
6253+
],
6254+
)
62286255
@pytest.mark.parametrize("hue_factor", [-0.5, -0.3, 0.0, 0.2, 0.5])
6229-
def test_correctness_image(self, hue_factor):
6230-
image = make_image(dtype=torch.uint8, device="cpu")
6256+
def test_correctness_image(self, make_input, hue_factor):
6257+
image = make_input(dtype=torch.uint8, device="cpu")
62316258

62326259
actual = F.adjust_hue(image, hue_factor=hue_factor)
6260+
6261+
if make_input is make_image_cvcuda:
6262+
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
6263+
actual = actual.squeeze(0)
6264+
image = F.cvcuda_to_tensor(image)
6265+
image = image.squeeze(0)
6266+
62336267
expected = F.to_image(F.adjust_hue(F.to_pil_image(image), hue_factor=hue_factor))
62346268

62356269
mae = (actual.float() - expected.float()).abs().mean()

torchvision/transforms/v2/functional/_color.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,31 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
457457
return adjust_hue_image(video, hue_factor=hue_factor)
458458

459459

460+
def _adjust_hue_cvcuda(image: "cvcuda.Tensor", hue_factor: float) -> "cvcuda.Tensor":
461+
cvcuda = _import_cvcuda()
462+
463+
if not (-0.5 <= hue_factor <= 0.5):
464+
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
465+
466+
c = image.shape[3]
467+
if c not in [1, 3, 4]:
468+
raise TypeError(f"Input image tensor permitted channel values are 1, 3, or 4, but found {c}")
469+
470+
if c == 1: # Match PIL behaviour
471+
return image
472+
473+
# no native adjust_hue, use CV-CUDA for color converison, use torch for elementwise operations
474+
hsv = cvcuda.cvtcolor(image, cvcuda.ColorConversion.RGB2HSV)
475+
hsv_torch = torch.as_tensor(hsv.cuda()).float()
476+
hsv_torch[..., 0] = (hsv_torch[..., 0] + hue_factor * 180) % 180
477+
hsv_modified = cvcuda.as_tensor(hsv_torch.to(torch.uint8), "NHWC")
478+
return cvcuda.cvtcolor(hsv_modified, cvcuda.ColorConversion.HSV2RGB)
479+
480+
481+
if CVCUDA_AVAILABLE:
482+
_register_kernel_internal(adjust_hue, _import_cvcuda().Tensor)(_adjust_hue_cvcuda)
483+
484+
460485
def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
461486
"""Adjust gamma."""
462487
if torch.jit.is_scripting():

0 commit comments

Comments
 (0)