@@ -197,6 +197,41 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
197197 return adjust_saturation_image (video , saturation_factor = saturation_factor )
198198
199199
200+ def _adjust_saturation_cvcuda (image : "cvcuda.Tensor" , saturation_factor : float ) -> "cvcuda.Tensor" :
201+ if saturation_factor < 0 :
202+ raise ValueError (f"saturation_factor ({ saturation_factor } ) is not non-negative." )
203+
204+ c = image .shape [- 1 ] # NHWC layout
205+ if c not in [1 , 3 , 4 ]:
206+ raise TypeError (f"Input image tensor permitted channel values are 1, 3, or 4, but found { c } " )
207+
208+ if c == 1 : # Match PIL behaviour
209+ return image
210+
211+ # Grayscale weights (same as _rgb_to_grayscale_image)
212+ sf = saturation_factor
213+ r , g , b = 0.2989 , 0.587 , 0.114
214+
215+ # Build 3x4 saturation matrix
216+ twist_data = [
217+ [sf + (1 - sf ) * r , (1 - sf ) * g , (1 - sf ) * b , 0.0 ],
218+ [(1 - sf ) * r , sf + (1 - sf ) * g , (1 - sf ) * b , 0.0 ],
219+ [(1 - sf ) * r , (1 - sf ) * g , sf + (1 - sf ) * b , 0.0 ],
220+ ]
221+ twist = cvcuda .Tensor (
222+ torch .tensor (twist_data , dtype = torch .float32 , device = "cuda" ).contiguous (),
223+ layout = "HW" ,
224+ )
225+
226+ return cvcuda .color_twist (image , twist )
227+
228+
229+ if CVCUDA_AVAILABLE :
230+ _register_kernel_internal (adjust_saturation , cvcuda .Tensor )(
231+ _adjust_saturation_cvcuda
232+ )
233+
234+
200235def adjust_contrast (inpt : torch .Tensor , contrast_factor : float ) -> torch .Tensor :
201236 """See :class:`~torchvision.transforms.RandomAutocontrast`"""
202237 if torch .jit .is_scripting ():
0 commit comments