Skip to content

Commit e0392a0

Browse files
committed
wip adjust_saturation
1 parent 310982c commit e0392a0

File tree

1 file changed

+35
-0
lines changed
  • torchvision/transforms/v2/functional

1 file changed

+35
-0
lines changed

torchvision/transforms/v2/functional/_color.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,41 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
197197
return adjust_saturation_image(video, saturation_factor=saturation_factor)
198198

199199

200+
def _adjust_saturation_cvcuda(image: "cvcuda.Tensor", saturation_factor: float) -> "cvcuda.Tensor":
201+
if saturation_factor < 0:
202+
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
203+
204+
c = image.shape[-1] # NHWC layout
205+
if c not in [1, 3, 4]:
206+
raise TypeError(f"Input image tensor permitted channel values are 1, 3, or 4, but found {c}")
207+
208+
if c == 1: # Match PIL behaviour
209+
return image
210+
211+
# Grayscale weights (same as _rgb_to_grayscale_image)
212+
sf = saturation_factor
213+
r, g, b = 0.2989, 0.587, 0.114
214+
215+
# Build 3x4 saturation matrix
216+
twist_data = [
217+
[sf + (1 - sf) * r, (1 - sf) * g, (1 - sf) * b, 0.0],
218+
[(1 - sf) * r, sf + (1 - sf) * g, (1 - sf) * b, 0.0],
219+
[(1 - sf) * r, (1 - sf) * g, sf + (1 - sf) * b, 0.0],
220+
]
221+
twist = cvcuda.Tensor(
222+
torch.tensor(twist_data, dtype=torch.float32, device="cuda").contiguous(),
223+
layout="HW",
224+
)
225+
226+
return cvcuda.color_twist(image, twist)
227+
228+
229+
if CVCUDA_AVAILABLE:
230+
_register_kernel_internal(adjust_saturation, cvcuda.Tensor)(
231+
_adjust_saturation_cvcuda
232+
)
233+
234+
200235
def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
201236
"""See :class:`~torchvision.transforms.RandomAutocontrast`"""
202237
if torch.jit.is_scripting():

0 commit comments

Comments
 (0)