Skip to content

Commit 61b237c

Browse files
committed
adjust saturation complete and tested
1 parent e0392a0 commit 61b237c

File tree

2 files changed

+45
-14
lines changed

2 files changed

+45
-14
lines changed

test/test_transforms_v2.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6277,7 +6277,18 @@ def test_kernel_image(self, dtype, device):
62776277
def test_kernel_video(self):
62786278
check_kernel(F.adjust_saturation_video, make_video(), saturation_factor=0.5)
62796279

6280-
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
6280+
@pytest.mark.parametrize(
6281+
"make_input",
6282+
[
6283+
make_image_tensor,
6284+
make_image,
6285+
make_image_pil,
6286+
make_video,
6287+
pytest.param(
6288+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
6289+
),
6290+
],
6291+
)
62816292
def test_functional(self, make_input):
62826293
check_functional(F.adjust_saturation, make_input(), saturation_factor=0.5)
62836294

@@ -6288,9 +6299,16 @@ def test_functional(self, make_input):
62886299
(F._color._adjust_saturation_image_pil, PIL.Image.Image),
62896300
(F.adjust_saturation_image, tv_tensors.Image),
62906301
(F.adjust_saturation_video, tv_tensors.Video),
6302+
pytest.param(
6303+
F._color._adjust_saturation_cvcuda,
6304+
"cvcuda.Tensor",
6305+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
6306+
),
62916307
],
62926308
)
62936309
def test_functional_signature(self, kernel, input_type):
6310+
if input_type == "cvcuda.Tensor":
6311+
input_type = _import_cvcuda().Tensor
62946312
check_functional_kernel_signature_match(F.adjust_saturation, kernel=kernel, input_type=input_type)
62956313

62966314
def test_functional_error(self):
@@ -6300,11 +6318,28 @@ def test_functional_error(self):
63006318
with pytest.raises(ValueError, match="is not non-negative"):
63016319
F.adjust_saturation(make_image(), saturation_factor=-1)
63026320

6321+
@pytest.mark.parametrize(
6322+
"make_input",
6323+
[
6324+
make_image,
6325+
pytest.param(
6326+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
6327+
),
6328+
],
6329+
)
6330+
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
63036331
@pytest.mark.parametrize("saturation_factor", [0.1, 0.5, 1.0])
6304-
def test_correctness_image(self, saturation_factor):
6305-
image = make_image(dtype=torch.uint8, device="cpu")
6332+
def test_correctness_image(self, make_input, color_space, saturation_factor):
6333+
image = make_input(dtype=torch.uint8, color_space=color_space, device="cpu")
63066334

63076335
actual = F.adjust_saturation(image, saturation_factor=saturation_factor)
6336+
6337+
if make_input is make_image_cvcuda:
6338+
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
6339+
actual = actual.squeeze(0)
6340+
image = F.cvcuda_to_tensor(image)
6341+
image = image.squeeze(0)
6342+
63086343
expected = F.to_image(F.adjust_saturation(F.to_pil_image(image), saturation_factor=saturation_factor))
63096344

63106345
assert_close(actual, expected, rtol=0, atol=1)

torchvision/transforms/v2/functional/_color.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -201,35 +201,31 @@ def _adjust_saturation_cvcuda(image: "cvcuda.Tensor", saturation_factor: float)
201201
if saturation_factor < 0:
202202
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
203203

204-
c = image.shape[-1] # NHWC layout
204+
c = image.shape[3]
205205
if c not in [1, 3, 4]:
206206
raise TypeError(f"Input image tensor permitted channel values are 1, 3, or 4, but found {c}")
207207

208208
if c == 1: # Match PIL behaviour
209209
return image
210210

211-
# Grayscale weights (same as _rgb_to_grayscale_image)
211+
# grayscale weights
212212
sf = saturation_factor
213213
r, g, b = 0.2989, 0.587, 0.114
214-
215-
# Build 3x4 saturation matrix
216214
twist_data = [
217215
[sf + (1 - sf) * r, (1 - sf) * g, (1 - sf) * b, 0.0],
218216
[(1 - sf) * r, sf + (1 - sf) * g, (1 - sf) * b, 0.0],
219217
[(1 - sf) * r, (1 - sf) * g, sf + (1 - sf) * b, 0.0],
220218
]
221-
twist = cvcuda.Tensor(
222-
torch.tensor(twist_data, dtype=torch.float32, device="cuda").contiguous(),
223-
layout="HW",
219+
twist_tensor = cvcuda.as_tensor(
220+
torch.tensor(twist_data, dtype=torch.float32, device="cuda"),
221+
"HW",
224222
)
225223

226-
return cvcuda.color_twist(image, twist)
224+
return cvcuda.color_twist(image, twist_tensor)
227225

228226

229227
if CVCUDA_AVAILABLE:
230-
_register_kernel_internal(adjust_saturation, cvcuda.Tensor)(
231-
_adjust_saturation_cvcuda
232-
)
228+
_register_kernel_internal(adjust_saturation, cvcuda.Tensor)(_adjust_saturation_cvcuda)
233229

234230

235231
def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:

0 commit comments

Comments
 (0)