From c92bc32b10fca0d0e3e13a20b88dd9b2006b7ad1 Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Wed, 1 Oct 2025 09:12:49 -0700 Subject: [PATCH 01/10] add `clamping_mode` parameter to `KeyPoints` constructor --- torchvision/tv_tensors/__init__.py | 4 ++-- torchvision/tv_tensors/_keypoints.py | 20 ++++++++++++++------ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py index 744e5241135..631716a9370 100644 --- a/torchvision/tv_tensors/__init__.py +++ b/torchvision/tv_tensors/__init__.py @@ -1,6 +1,6 @@ import torch -from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, is_rotated_bounding_format +from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, is_rotated_bounding_format, CLAMPING_MODE_TYPE from ._image import Image from ._keypoints import KeyPoints from ._mask import Mask @@ -34,6 +34,6 @@ def wrap(wrappee, *, like, **kwargs): clamping_mode=kwargs.get("clamping_mode", like.clamping_mode), ) elif isinstance(like, KeyPoints): - return KeyPoints._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size)) + return KeyPoints._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size), clamping_mode=kwargs.get("clamping_mode", like.clamping_mode)) else: return wrappee.as_subclass(type(like)) diff --git a/torchvision/tv_tensors/_keypoints.py b/torchvision/tv_tensors/_keypoints.py index aede31ad7db..04262388c8a 100644 --- a/torchvision/tv_tensors/_keypoints.py +++ b/torchvision/tv_tensors/_keypoints.py @@ -6,6 +6,7 @@ from torch.utils._pytree import tree_flatten from ._tv_tensor import TVTensor +from ._bounding_boxes import CLAMPING_MODE_TYPE class KeyPoints(TVTensor): @@ -43,6 +44,8 @@ class KeyPoints(TVTensor): :func:`torch.as_tensor`. canvas_size (two-tuple of ints): Height and width of the corresponding image or video. + clamping_mode: The clamping mode to use when applying transforms that may result in key points + outside of the image. Possible values are: "soft", "hard", or ``None``. Read more in :ref:`clamping_mode_tuto`. dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from ``data``. device (torch.device, optional): Desired device of the bounding box. 
If @@ -55,16 +58,20 @@ class KeyPoints(TVTensor): """ canvas_size: tuple[int, int] + clamping_mode: CLAMPING_MODE_TYPE @classmethod - def _wrap(cls, tensor: torch.Tensor, *, canvas_size: tuple[int, int], check_dims: bool = True) -> KeyPoints: # type: ignore[override] + def _wrap(cls, tensor: torch.Tensor, *, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft", check_dims: bool = True) -> KeyPoints: # type: ignore[override] if check_dims: if tensor.ndim == 1: tensor = tensor.unsqueeze(0) elif tensor.shape[-1] != 2: raise ValueError(f"Expected a tensor of shape (..., 2), not {tensor.shape}") + if clamping_mode is not None and clamping_mode not in ("hard", "soft"): + raise ValueError(f"clamping_mode must be None, hard or soft, got {clamping_mode}.") points = tensor.as_subclass(cls) points.canvas_size = canvas_size + points.clamping_mode = clamping_mode return points def __new__( @@ -72,12 +79,13 @@ def __new__( data: Any, *, canvas_size: tuple[int, int], + clamping_mode: CLAMPING_MODE_TYPE = "soft", dtype: torch.dtype | None = None, device: torch.device | str | int | None = None, requires_grad: bool | None = None, ) -> KeyPoints: tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) - return cls._wrap(tensor, canvas_size=canvas_size) + return cls._wrap(tensor, canvas_size=canvas_size, clamping_mode=clamping_mode) @classmethod def _wrap_output( @@ -89,14 +97,14 @@ def _wrap_output( # Similar to BoundingBoxes._wrap_output(), see comment there. flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ())) # type: ignore[operator] first_keypoints_from_args = next(x for x in flat_params if isinstance(x, KeyPoints)) - canvas_size = first_keypoints_from_args.canvas_size + canvas_size, clamping_mode = first_keypoints_from_args.canvas_size, first_keypoints_from_args.clamping_mode if isinstance(output, torch.Tensor) and not isinstance(output, KeyPoints): - output = KeyPoints._wrap(output, canvas_size=canvas_size, check_dims=False) + output = KeyPoints._wrap(output, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False) elif isinstance(output, (tuple, list)): # This branch exists for chunk() and unbind() - output = type(output)(KeyPoints._wrap(part, canvas_size=canvas_size, check_dims=False) for part in output) + output = type(output)(KeyPoints._wrap(part, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False) for part in output) return output def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] - return self._make_repr(canvas_size=self.canvas_size) + return self._make_repr(canvas_size=self.canvas_size, clamping_mode=self.clamping_mode) From fe510804ec31ed11073505b7af2f46b1e041920d Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Wed, 1 Oct 2025 11:45:24 -0700 Subject: [PATCH 02/10] add `clamping_mode` parameter to `clamp_keypoints` functional and class --- torchvision/transforms/v2/_meta.py | 23 ++++--- .../transforms/v2/functional/_geometry.py | 63 +++++++++++-------- torchvision/transforms/v2/functional/_meta.py | 20 ++++-- 3 files changed, 67 insertions(+), 39 deletions(-) diff --git a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py index 39f223f0398..955f3db46b3 100644 --- a/torchvision/transforms/v2/_meta.py +++ b/torchvision/transforms/v2/_meta.py @@ -2,7 +2,7 @@ from torchvision import tv_tensors from torchvision.transforms.v2 import functional as F, Transform -from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE 
+from torchvision.tv_tensors import CLAMPING_MODE_TYPE class ConvertBoundingBoxFormat(Transform): @@ -46,17 +46,26 @@ def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> t class ClampKeyPoints(Transform): """Clamp keypoints to their corresponding image dimensions. - The clamping is done according to the keypoints' ``canvas_size`` meta-data. + Args: + clamping_mode: Default is "auto" which relies on the input keypoint' + ``clamping_mode`` attribute. + The clamping is done according to the keypoints' ``canvas_size`` meta-data. + Read more in :ref:`clamping_mode_tuto` + for more details on how to use this transform. + """ + def __init__(self, clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto") -> None: + super().__init__() + self.clamping_mode = clamping_mode _transformed_types = (tv_tensors.KeyPoints,) def transform(self, inpt: tv_tensors.KeyPoints, params: dict[str, Any]) -> tv_tensors.KeyPoints: - return F.clamp_keypoints(inpt) # type: ignore[return-value] + return F.clamp_keypoints(inpt, clamping_mode=self.clamping_mode) # type: ignore[return-value] class SetClampingMode(Transform): - """Sets the ``clamping_mode`` attribute of the bounding boxes for future transforms. + """Sets the ``clamping_mode`` attribute of the bounding boxes and keypoints for future transforms. @@ -73,9 +82,9 @@ def __init__(self, clamping_mode: CLAMPING_MODE_TYPE) -> None: if self.clamping_mode not in (None, "soft", "hard"): raise ValueError(f"clamping_mode must be soft, hard or None, got {clamping_mode}") - _transformed_types = (tv_tensors.BoundingBoxes,) + _transformed_types = (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints) - def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes: - out: tv_tensors.BoundingBoxes = inpt.clone() # type: ignore[assignment] + def transform(self, inpt: tv_tensors.TVTensor, params: dict[str, Any]) -> tv_tensors.TVTensor: + out: tv_tensors.TVTensor = inpt.clone() # type: ignore[assignment] out.clamping_mode = self.clamping_mode return out diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 0c7eab0c04e..3ee8ee7511e 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -20,7 +20,7 @@ pil_to_tensor, to_pil_image, ) -from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE +from torchvision.tv_tensors import CLAMPING_MODE_TYPE from torchvision.utils import _log_api_usage_once @@ -67,16 +67,16 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: return horizontal_flip_image(mask) -def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]): +def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft"): shape = keypoints.shape keypoints = keypoints.clone().reshape(-1, 2) keypoints[..., 0] = keypoints[..., 0].sub_(canvas_size[1] - 1).neg_() - return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size) + return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size, clamping_mode=clamping_mode) @_register_kernel_internal(horizontal_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _horizontal_flip_keypoints_dispatch(keypoints: tv_tensors.KeyPoints): - out = horizontal_flip_keypoints(keypoints.as_subclass(torch.Tensor), canvas_size=keypoints.canvas_size) + out = horizontal_flip_keypoints(keypoints.as_subclass(torch.Tensor), 
canvas_size=keypoints.canvas_size, clamping_mode=keypoints.clamping_mode) return tv_tensors.wrap(out, like=keypoints) @@ -155,11 +155,11 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: return vertical_flip_image(mask) -def vertical_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]) -> torch.Tensor: +def vertical_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft",) -> torch.Tensor: shape = keypoints.shape keypoints = keypoints.clone().reshape(-1, 2) keypoints[..., 1] = keypoints[..., 1].sub_(canvas_size[0] - 1).neg_() - return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size) + return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size, clamping_mode=clamping_mode) def vertical_flip_bounding_boxes( @@ -199,7 +199,7 @@ def vertical_flip_bounding_boxes( @_register_kernel_internal(vertical_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _vertical_flip_keypoints_dispatch(inpt: tv_tensors.KeyPoints) -> tv_tensors.KeyPoints: - output = vertical_flip_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size) + output = vertical_flip_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, clamping_mode=inpt.clamping_mode) return tv_tensors.wrap(output, like=inpt) @@ -968,6 +968,7 @@ def _affine_keypoints_with_expand( shear: list[float], center: Optional[list[float]] = None, expand: bool = False, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: if keypoints.numel() == 0: return keypoints, canvas_size @@ -1026,7 +1027,7 @@ def _affine_keypoints_with_expand( new_width, new_height = _compute_affine_output_size(affine_vector, width, height) canvas_size = (new_height, new_width) - out_keypoints = clamp_keypoints(transformed_points, canvas_size=canvas_size).reshape(original_shape) + out_keypoints = clamp_keypoints(transformed_points, canvas_size=canvas_size, clamping_mode=clamping_mode).reshape(original_shape) out_keypoints = out_keypoints.to(original_dtype) return out_keypoints, canvas_size @@ -1040,6 +1041,7 @@ def affine_keypoints( scale: float, shear: list[float], center: Optional[list[float]] = None, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ): return _affine_keypoints_with_expand( keypoints=keypoints, @@ -1050,6 +1052,7 @@ def affine_keypoints( shear=shear, center=center, expand=False, + clamping_mode=clamping_mode, ) @@ -1071,6 +1074,7 @@ def _affine_keypoints_dispatch( scale=scale, shear=shear, center=center, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -1393,6 +1397,7 @@ def rotate_keypoints( angle: float, expand: bool = False, center: Optional[list[float]] = None, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: return _affine_keypoints_with_expand( keypoints=keypoints, @@ -1403,6 +1408,7 @@ def rotate_keypoints( shear=[0.0, 0.0], center=center, expand=expand, + clamping_mode=clamping_mode, ) @@ -1411,7 +1417,7 @@ def _rotate_keypoints_dispatch( inpt: tv_tensors.KeyPoints, angle: float, expand: bool = False, center: Optional[list[float]] = None, **kwargs ) -> tv_tensors.KeyPoints: output, canvas_size = rotate_keypoints( - inpt, canvas_size=inpt.canvas_size, angle=angle, center=center, expand=expand + inpt, canvas_size=inpt.canvas_size, angle=angle, center=center, expand=expand, clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -1683,7 
+1689,7 @@ def pad_mask( def pad_keypoints( - keypoints: torch.Tensor, canvas_size: tuple[int, int], padding: list[int], padding_mode: str = "constant" + keypoints: torch.Tensor, canvas_size: tuple[int, int], padding: list[int], padding_mode: str = "constant", clamping_mode: CLAMPING_MODE_TYPE = "soft" ): SUPPORTED_MODES = ["constant"] if padding_mode not in SUPPORTED_MODES: @@ -1695,20 +1701,21 @@ def pad_keypoints( left, right, top, bottom = _parse_pad_padding(padding) pad = torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device) canvas_size = (canvas_size[0] + top + bottom, canvas_size[1] + left + right) - return clamp_keypoints(keypoints + pad, canvas_size), canvas_size + return clamp_keypoints(keypoints + pad, canvas_size=canvas_size, clamping_mode=clamping_mode), canvas_size @_register_kernel_internal(pad, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _pad_keypoints_dispatch( - keypoints: tv_tensors.KeyPoints, padding: list[int], padding_mode: str = "constant", **kwargs + inpt: tv_tensors.KeyPoints, padding: list[int], padding_mode: str = "constant", **kwargs ) -> tv_tensors.KeyPoints: output, canvas_size = pad_keypoints( - keypoints.as_subclass(torch.Tensor), - canvas_size=keypoints.canvas_size, + inpt.as_subclass(torch.Tensor), + canvas_size=inpt.canvas_size, padding=padding, padding_mode=padding_mode, + clamping_mode=inpt.clamping_mode, ) - return tv_tensors.wrap(output, like=keypoints, canvas_size=canvas_size) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) def pad_bounding_boxes( @@ -1812,19 +1819,20 @@ def crop_keypoints( left: int, height: int, width: int, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: keypoints = keypoints - torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device) canvas_size = (height, width) - return clamp_keypoints(keypoints, canvas_size=canvas_size), canvas_size + return clamp_keypoints(keypoints, canvas_size=canvas_size, clamping_mode=clamping_mode), canvas_size @_register_kernel_internal(crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _crop_keypoints_dispatch( inpt: tv_tensors.KeyPoints, top: int, left: int, height: int, width: int ) -> tv_tensors.KeyPoints: - output, canvas_size = crop_keypoints(inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width) + output, canvas_size = crop_keypoints(inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, clamping_mode=inpt.clamping_mode) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -2024,6 +2032,7 @@ def perspective_keypoints( startpoints: Optional[list[list[int]]], endpoints: Optional[list[list[int]]], coefficients: Optional[list[float]] = None, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ): if keypoints.numel() == 0: return keypoints @@ -2047,7 +2056,7 @@ def perspective_keypoints( numer_points = torch.matmul(points, theta1.T) denom_points = torch.matmul(points, theta2.T) transformed_points = numer_points.div_(denom_points) - return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size).reshape(original_shape) + return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size, clamping_mode=clamping_mode).reshape(original_shape) @_register_kernel_internal(perspective, tv_tensors.KeyPoints, tv_tensor_wrapper=False) @@ -2064,6 +2073,7 @@ def _perspective_keypoints_dispatch( startpoints=startpoints, endpoints=endpoints, coefficients=coefficients, + clamping_mode=inpt.clamping_mode, ) return 
tv_tensors.wrap(output, like=inpt) @@ -2344,7 +2354,7 @@ def _create_identity_grid(size: tuple[int, int], device: torch.device, dtype: to def elastic_keypoints( - keypoints: torch.Tensor, canvas_size: tuple[int, int], displacement: torch.Tensor + keypoints: torch.Tensor, canvas_size: tuple[int, int], displacement: torch.Tensor, clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> torch.Tensor: expected_shape = (1, canvas_size[0], canvas_size[1], 2) if not isinstance(displacement, torch.Tensor): @@ -2376,12 +2386,12 @@ def elastic_keypoints( t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype) transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5) - return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size).reshape(original_shape) + return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size, clamping_mode=clamping_mode).reshape(original_shape) @_register_kernel_internal(elastic, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _elastic_keypoints_dispatch(inpt: tv_tensors.KeyPoints, displacement: torch.Tensor, **kwargs): - output = elastic_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, displacement=displacement) + output = elastic_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, displacement=displacement, clamping_mode=inpt.clamping_mode) return tv_tensors.wrap(output, like=inpt) @@ -2578,16 +2588,16 @@ def _center_crop_image_pil(image: PIL.Image.Image, output_size: list[int]) -> PI return _crop_image_pil(image, crop_top, crop_left, crop_height, crop_width) -def center_crop_keypoints(inpt: torch.Tensor, canvas_size: tuple[int, int], output_size: list[int]): +def center_crop_keypoints(inpt: torch.Tensor, canvas_size: tuple[int, int], output_size: list[int], clamping_mode: CLAMPING_MODE_TYPE = "soft",): crop_height, crop_width = _center_crop_parse_output_size(output_size) crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size) - return crop_keypoints(inpt, top=crop_top, left=crop_left, height=crop_height, width=crop_width) + return crop_keypoints(inpt, top=crop_top, left=crop_left, height=crop_height, width=crop_width, clamping_mode=clamping_mode) @_register_kernel_internal(center_crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _center_crop_keypoints_dispatch(inpt: tv_tensors.KeyPoints, output_size: list[int]) -> tv_tensors.KeyPoints: output, canvas_size = center_crop_keypoints( - inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, output_size=output_size + inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, output_size=output_size, clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -2745,8 +2755,9 @@ def resized_crop_keypoints( height: int, width: int, size: list[int], + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> tuple[torch.Tensor, tuple[int, int]]: - keypoints, canvas_size = crop_keypoints(keypoints, top, left, height, width) + keypoints, canvas_size = crop_keypoints(keypoints, top, left, height, width, clamping_mode=clamping_mode) return resize_keypoints(keypoints, size=size, canvas_size=canvas_size) @@ -2755,7 +2766,7 @@ def _resized_crop_keypoints_dispatch( inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: list[int], **kwargs ): output, canvas_size = resized_crop_keypoints( - inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, 
width=width, size=size
+        inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size, clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py
index 4568b39ab59..1774e2db982 100644
--- a/torchvision/transforms/v2/functional/_meta.py
+++ b/torchvision/transforms/v2/functional/_meta.py
@@ -5,7 +5,7 @@
 import torch
 from torchvision import tv_tensors
 from torchvision.transforms import _functional_pil as _FP
 from torchvision.tv_tensors import BoundingBoxFormat
-from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE
+from torchvision.tv_tensors import CLAMPING_MODE_TYPE
 from torchvision.utils import _log_api_usage_once
@@ -653,7 +653,9 @@ def clamp_bounding_boxes(
     )
 
 
-def _clamp_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]) -> torch.Tensor:
+def _clamp_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE) -> torch.Tensor:
+    if clamping_mode is None or clamping_mode != "hard":
+        return keypoints.clone()
     dtype = keypoints.dtype
     keypoints = keypoints.clone() if keypoints.is_floating_point() else keypoints.float()
     # Note that max is canvas_size[i] - 1 and not can canvas_size[i] like for
@@ -666,20 +668,26 @@ def _clamp_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]) -> t
 
 def clamp_keypoints(
     inpt: torch.Tensor,
     canvas_size: Optional[tuple[int, int]] = None,
+    clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto",
 ) -> torch.Tensor:
     """See :func:`~torchvision.transforms.v2.ClampKeyPoints` for details."""
     if not torch.jit.is_scripting():
         _log_api_usage_once(clamp_keypoints)
 
+    if clamping_mode is not None and clamping_mode not in ("soft", "hard", "auto"):
+        raise ValueError(f"clamping_mode must be soft, hard, auto or None, got {clamping_mode}")
+
     if torch.jit.is_scripting() or is_pure_tensor(inpt):
-        if canvas_size is None:
-            raise ValueError("For pure tensor inputs, `canvas_size` has to be passed.")
-        return _clamp_keypoints(inpt, canvas_size=canvas_size)
+        if canvas_size is None or (clamping_mode is not None and clamping_mode == "auto"):
+            raise ValueError("For pure tensor inputs, `canvas_size` and `clamping_mode` have to be passed.")
+        return _clamp_keypoints(inpt, canvas_size=canvas_size, clamping_mode=clamping_mode)
     elif isinstance(inpt, tv_tensors.KeyPoints):
         if canvas_size is not None:
             raise ValueError("For keypoints tv_tensor inputs, `canvas_size` must not be passed.")
-        output = _clamp_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size)
+        if clamping_mode is not None and clamping_mode == "auto":
+            clamping_mode = inpt.clamping_mode
+        output = _clamp_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, clamping_mode=clamping_mode)
         return tv_tensors.wrap(output, like=inpt)
     else:
         raise TypeError(f"Input can either be a plain tensor or a keypoints tv_tensor, but got {type(inpt)} instead.")

From 436899955b732255727e8925f3343bd4803da4d1 Mon Sep 17 00:00:00 2001
From: Antoine Simoulin
Date: Wed, 1 Oct 2025 11:45:24 -0700
Subject: [PATCH 03/10] fix tests

---
 test/common_utils.py       |  4 +--
 test/test_transforms_v2.py | 67 +++++++++++++++++++++++++++++++++-----
 2 files changed, 61 insertions(+), 10 deletions(-)

diff --git a/test/common_utils.py b/test/common_utils.py
index 74ad31fea72..6dbf34db9bb 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -400,10 +400,10 @@ def make_image_pil(*args, 
**kwargs): return to_pil_image(make_image(*args, **kwargs)) -def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"): +def make_keypoints(canvas_size=DEFAULT_SIZE, *, clamping_mode="soft", num_points=4, dtype=None, device="cpu"): y = torch.randint(0, canvas_size[0], size=(num_points, 1), dtype=dtype, device=device) x = torch.randint(0, canvas_size[1], size=(num_points, 1), dtype=dtype, device=device) - return tv_tensors.KeyPoints(torch.cat((x, y), dim=-1), canvas_size=canvas_size) + return tv_tensors.KeyPoints(torch.cat((x, y), dim=-1), canvas_size=canvas_size, clamping_mode=clamping_mode) def make_bounding_boxes( diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f92f2a0bc67..a66624eba61 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -633,6 +633,7 @@ def affine_rotated_bounding_boxes(bounding_boxes): def reference_affine_keypoints_helper(keypoints, *, affine_matrix, new_canvas_size=None, clamp=True): canvas_size = new_canvas_size or keypoints.canvas_size + clamping_mode = keypoints.clamping_mode def affine_keypoints(keypoints): dtype = keypoints.dtype @@ -652,7 +653,7 @@ def affine_keypoints(keypoints): ) if clamp: - output = F.clamp_keypoints(output, canvas_size=canvas_size) + output = F.clamp_keypoints(output, canvas_size=canvas_size, clamping_mode=clamping_mode) else: dtype = output.dtype @@ -660,7 +661,7 @@ def affine_keypoints(keypoints): return tv_tensors.KeyPoints( torch.cat([affine_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape(keypoints.shape), - canvas_size=canvas_size, + canvas_size=canvas_size, clamping_mode=clamping_mode ) @@ -3309,7 +3310,6 @@ def test_functional(self, make_input): (F.elastic_image, tv_tensors.Image), (F.elastic_mask, tv_tensors.Mask), (F.elastic_video, tv_tensors.Video), - (F.elastic_keypoints, tv_tensors.KeyPoints), ], ) def test_functional_signature(self, kernel, input_type): @@ -5325,6 +5325,7 @@ def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, fo def _reference_perspective_keypoints(self, keypoints, *, startpoints, endpoints): canvas_size = keypoints.canvas_size + clamping_mode = keypoints.clamping_mode dtype = keypoints.dtype device = keypoints.device @@ -5364,6 +5365,7 @@ def perspective_keypoints(keypoints): return F.clamp_keypoints( output, canvas_size=canvas_size, + clamping_mode=clamping_mode ).to(dtype=dtype, device=device) return tv_tensors.KeyPoints( @@ -5371,6 +5373,7 @@ def perspective_keypoints(keypoints): keypoints.shape ), canvas_size=canvas_size, + clamping_mode=clamping_mode, ) @pytest.mark.parametrize(("startpoints", "endpoints"), START_END_POINTS) @@ -5733,32 +5736,80 @@ def test_error(self): class TestClampKeyPoints: + @pytest.mark.parametrize("clamping_mode", ("soft", "hard", None)) @pytest.mark.parametrize("dtype", [torch.int64, torch.float32]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_kernel(self, dtype, device): - keypoints = make_keypoints(dtype=dtype, device=device) + def test_kernel(self, clamping_mode, dtype, device): + keypoints = make_keypoints(dtype=dtype, device=device, clamping_mode=clamping_mode) check_kernel( F.clamp_keypoints, keypoints, canvas_size=keypoints.canvas_size, + clamping_mode=clamping_mode, ) - def test_functional(self): - check_functional(F.clamp_keypoints, make_keypoints()) + @pytest.mark.parametrize("clamping_mode", ("soft", "hard", None)) + def test_functional(self, clamping_mode): + check_functional(F.clamp_keypoints, 
make_keypoints(clamping_mode=clamping_mode))
 
     def test_errors(self):
         input_tv_tensor = make_keypoints()
         input_pure_tensor = input_tv_tensor.as_subclass(torch.Tensor)
 
-        with pytest.raises(ValueError, match="`canvas_size` has to be passed"):
+        with pytest.raises(ValueError, match="`canvas_size` and `clamping_mode` have to be passed."):
             F.clamp_keypoints(input_pure_tensor, canvas_size=None)
 
         with pytest.raises(ValueError, match="`canvas_size` must not be passed"):
             F.clamp_keypoints(input_tv_tensor, canvas_size=input_tv_tensor.canvas_size)
+        with pytest.raises(ValueError, match="clamping_mode must be soft,"):
+            F.clamp_keypoints(input_tv_tensor, clamping_mode="bad")
+        with pytest.raises(ValueError, match="clamping_mode must be soft,"):
+            transforms.ClampKeyPoints(clamping_mode="bad")(input_tv_tensor)
 
     def test_transform(self):
         check_transform(transforms.ClampKeyPoints(), make_keypoints())
 
+    @pytest.mark.parametrize("constructor_clamping_mode", ("soft", "hard", None))
+    @pytest.mark.parametrize("clamping_mode", ("soft", "hard", None, "auto"))
+    @pytest.mark.parametrize("pass_pure_tensor", (True, False))
+    @pytest.mark.parametrize("fn", [F.clamp_keypoints, transform_cls_to_functional(transforms.ClampKeyPoints)])
+    def test_clamping_mode(self, constructor_clamping_mode, clamping_mode, pass_pure_tensor, fn):
+        # This test checks 2 things:
+        # - That passing clamping_mode="auto" to the clamp_keypoints
+        #   functional (or to the class) relies on the keypoints' `.clamping_mode`
+        #   attribute
+        # - That clamping happens when it should, and only when it should, i.e.
+        #   when the clamping mode is not None. It doesn't validate the
+        #   numerical results, only that clamping happened. For that, we create
+        #   keypoints with large coordinates (100) inside of a small 10x10 image.
+ + if pass_pure_tensor and fn is not F.clamp_keypoints: + # Only the functional supports pure tensors, not the class + return + if pass_pure_tensor and clamping_mode == "auto": + # cannot leave clamping_mode="auto" when passing pure tensor + return + + keypoints = tv_tensors.KeyPoints( + [[0, 100], [0, 100]],canvas_size=(10, 10), clamping_mode=constructor_clamping_mode + ) + expected_clamped_output = torch.tensor([[0, 9], [0, 9]]) if clamping_mode == "hard" else torch.tensor([[0, 100], [0, 100]]) + + if pass_pure_tensor: + out = fn( + keypoints.as_subclass(torch.Tensor), + canvas_size=keypoints.canvas_size, + clamping_mode=clamping_mode, + ) + else: + out = fn(keypoints, clamping_mode=clamping_mode) + + clamping_mode_prevailing = constructor_clamping_mode if clamping_mode == "auto" else clamping_mode + if clamping_mode_prevailing is None: + assert_equal(keypoints, out) # should be a pass-through + else: + assert_equal(out, expected_clamped_output) + class TestInvert: @pytest.mark.parametrize("dtype", [torch.uint8, torch.int16, torch.float32]) From e9e901dac11efa557e6c9aebb1eeb9f898ee94d8 Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Wed, 1 Oct 2025 11:45:24 -0700 Subject: [PATCH 04/10] lint --- test/test_transforms_v2.py | 20 +++-- torchvision/transforms/v2/_meta.py | 15 ++-- .../transforms/v2/functional/_geometry.py | 83 +++++++++++++++---- torchvision/transforms/v2/functional/_meta.py | 11 ++- torchvision/tv_tensors/__init__.py | 2 +- torchvision/tv_tensors/_keypoints.py | 8 +- 6 files changed, 101 insertions(+), 38 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index a66624eba61..c98c16e0d4c 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -661,7 +661,8 @@ def affine_keypoints(keypoints): return tv_tensors.KeyPoints( torch.cat([affine_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape(keypoints.shape), - canvas_size=canvas_size, clamping_mode=clamping_mode + canvas_size=canvas_size, + clamping_mode=clamping_mode, ) @@ -5362,11 +5363,9 @@ def perspective_keypoints(keypoints): ) # It is important to clamp before casting, especially for CXCYWH format, dtype=int64 - return F.clamp_keypoints( - output, - canvas_size=canvas_size, - clamping_mode=clamping_mode - ).to(dtype=dtype, device=device) + return F.clamp_keypoints(output, canvas_size=canvas_size, clamping_mode=clamping_mode).to( + dtype=dtype, device=device + ) return tv_tensors.KeyPoints( torch.cat([perspective_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape( @@ -5791,9 +5790,14 @@ def test_clamping_mode(self, constructor_clamping_mode, clamping_mode, pass_pure return keypoints = tv_tensors.KeyPoints( - [[0, 100], [0, 100]],canvas_size=(10, 10), clamping_mode=constructor_clamping_mode + [[0, 100], [0, 100]], canvas_size=(10, 10), clamping_mode=constructor_clamping_mode + ) + expected_clamped_output = ( + torch.tensor([[0, 10], [0, 10]]) if clamping_mode == "hard" else torch.tensor([[0, 100], [0, 100]]) + ) + expected_clamped_output = ( + torch.tensor([[0, 9], [0, 9]]) if clamping_mode == "hard" else torch.tensor([[0, 100], [0, 100]]) ) - expected_clamped_output = torch.tensor([[0, 9], [0, 9]]) if clamping_mode == "hard" else torch.tensor([[0, 100], [0, 100]]) if pass_pure_tensor: out = fn( diff --git a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py index 955f3db46b3..e0aad2f3899 100644 --- a/torchvision/transforms/v2/_meta.py +++ b/torchvision/transforms/v2/_meta.py @@ -46,14 +46,15 @@ def 
transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> t class ClampKeyPoints(Transform): """Clamp keypoints to their corresponding image dimensions. - Args: - clamping_mode: Default is "auto" which relies on the input keypoint' - ``clamping_mode`` attribute. - The clamping is done according to the keypoints' ``canvas_size`` meta-data. - Read more in :ref:`clamping_mode_tuto` - for more details on how to use this transform. - + Args: + clamping_mode: Default is "auto" which relies on the input keypoint' + ``clamping_mode`` attribute. + The clamping is done according to the keypoints' ``canvas_size`` meta-data. + Read more in :ref:`clamping_mode_tuto` + for more details on how to use this transform. + """ + def __init__(self, clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto") -> None: super().__init__() self.clamping_mode = clamping_mode diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 3ee8ee7511e..0f1acdd887b 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -67,7 +67,9 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: return horizontal_flip_image(mask) -def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft"): +def horizontal_flip_keypoints( + keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft" +): shape = keypoints.shape keypoints = keypoints.clone().reshape(-1, 2) keypoints[..., 0] = keypoints[..., 0].sub_(canvas_size[1] - 1).neg_() @@ -76,7 +78,9 @@ def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, i @_register_kernel_internal(horizontal_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _horizontal_flip_keypoints_dispatch(keypoints: tv_tensors.KeyPoints): - out = horizontal_flip_keypoints(keypoints.as_subclass(torch.Tensor), canvas_size=keypoints.canvas_size, clamping_mode=keypoints.clamping_mode) + out = horizontal_flip_keypoints( + keypoints.as_subclass(torch.Tensor), canvas_size=keypoints.canvas_size, clamping_mode=keypoints.clamping_mode + ) return tv_tensors.wrap(out, like=keypoints) @@ -155,7 +159,11 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: return vertical_flip_image(mask) -def vertical_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft",) -> torch.Tensor: +def vertical_flip_keypoints( + keypoints: torch.Tensor, + canvas_size: tuple[int, int], + clamping_mode: CLAMPING_MODE_TYPE = "soft", +) -> torch.Tensor: shape = keypoints.shape keypoints = keypoints.clone().reshape(-1, 2) keypoints[..., 1] = keypoints[..., 1].sub_(canvas_size[0] - 1).neg_() @@ -199,7 +207,9 @@ def vertical_flip_bounding_boxes( @_register_kernel_internal(vertical_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _vertical_flip_keypoints_dispatch(inpt: tv_tensors.KeyPoints) -> tv_tensors.KeyPoints: - output = vertical_flip_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, clamping_mode=inpt.clamping_mode) + output = vertical_flip_keypoints( + inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, clamping_mode=inpt.clamping_mode + ) return tv_tensors.wrap(output, like=inpt) @@ -1027,7 +1037,9 @@ def _affine_keypoints_with_expand( new_width, new_height = _compute_affine_output_size(affine_vector, width, height) canvas_size = (new_height, new_width) - out_keypoints 
= clamp_keypoints(transformed_points, canvas_size=canvas_size, clamping_mode=clamping_mode).reshape(original_shape) + out_keypoints = clamp_keypoints(transformed_points, canvas_size=canvas_size, clamping_mode=clamping_mode).reshape( + original_shape + ) out_keypoints = out_keypoints.to(original_dtype) return out_keypoints, canvas_size @@ -1417,7 +1429,12 @@ def _rotate_keypoints_dispatch( inpt: tv_tensors.KeyPoints, angle: float, expand: bool = False, center: Optional[list[float]] = None, **kwargs ) -> tv_tensors.KeyPoints: output, canvas_size = rotate_keypoints( - inpt, canvas_size=inpt.canvas_size, angle=angle, center=center, expand=expand, clamping_mode=inpt.clamping_mode, + inpt, + canvas_size=inpt.canvas_size, + angle=angle, + center=center, + expand=expand, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -1689,7 +1706,11 @@ def pad_mask( def pad_keypoints( - keypoints: torch.Tensor, canvas_size: tuple[int, int], padding: list[int], padding_mode: str = "constant", clamping_mode: CLAMPING_MODE_TYPE = "soft" + keypoints: torch.Tensor, + canvas_size: tuple[int, int], + padding: list[int], + padding_mode: str = "constant", + clamping_mode: CLAMPING_MODE_TYPE = "soft", ): SUPPORTED_MODES = ["constant"] if padding_mode not in SUPPORTED_MODES: @@ -1832,7 +1853,9 @@ def crop_keypoints( def _crop_keypoints_dispatch( inpt: tv_tensors.KeyPoints, top: int, left: int, height: int, width: int ) -> tv_tensors.KeyPoints: - output, canvas_size = crop_keypoints(inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, clamping_mode=inpt.clamping_mode) + output, canvas_size = crop_keypoints( + inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, clamping_mode=inpt.clamping_mode + ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -2056,7 +2079,9 @@ def perspective_keypoints( numer_points = torch.matmul(points, theta1.T) denom_points = torch.matmul(points, theta2.T) transformed_points = numer_points.div_(denom_points) - return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size, clamping_mode=clamping_mode).reshape(original_shape) + return clamp_keypoints( + transformed_points.to(keypoints.dtype), canvas_size=canvas_size, clamping_mode=clamping_mode + ).reshape(original_shape) @_register_kernel_internal(perspective, tv_tensors.KeyPoints, tv_tensor_wrapper=False) @@ -2354,7 +2379,10 @@ def _create_identity_grid(size: tuple[int, int], device: torch.device, dtype: to def elastic_keypoints( - keypoints: torch.Tensor, canvas_size: tuple[int, int], displacement: torch.Tensor, clamping_mode: CLAMPING_MODE_TYPE = "soft", + keypoints: torch.Tensor, + canvas_size: tuple[int, int], + displacement: torch.Tensor, + clamping_mode: CLAMPING_MODE_TYPE = "soft", ) -> torch.Tensor: expected_shape = (1, canvas_size[0], canvas_size[1], 2) if not isinstance(displacement, torch.Tensor): @@ -2386,12 +2414,19 @@ def elastic_keypoints( t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype) transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5) - return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size, clamping_mode=clamping_mode).reshape(original_shape) + return clamp_keypoints( + transformed_points.to(keypoints.dtype), canvas_size=canvas_size, clamping_mode=clamping_mode + ).reshape(original_shape) @_register_kernel_internal(elastic, tv_tensors.KeyPoints, 
tv_tensor_wrapper=False) def _elastic_keypoints_dispatch(inpt: tv_tensors.KeyPoints, displacement: torch.Tensor, **kwargs): - output = elastic_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, displacement=displacement, clamping_mode=inpt.clamping_mode) + output = elastic_keypoints( + inpt.as_subclass(torch.Tensor), + canvas_size=inpt.canvas_size, + displacement=displacement, + clamping_mode=inpt.clamping_mode, + ) return tv_tensors.wrap(output, like=inpt) @@ -2588,16 +2623,26 @@ def _center_crop_image_pil(image: PIL.Image.Image, output_size: list[int]) -> PI return _crop_image_pil(image, crop_top, crop_left, crop_height, crop_width) -def center_crop_keypoints(inpt: torch.Tensor, canvas_size: tuple[int, int], output_size: list[int], clamping_mode: CLAMPING_MODE_TYPE = "soft",): +def center_crop_keypoints( + inpt: torch.Tensor, + canvas_size: tuple[int, int], + output_size: list[int], + clamping_mode: CLAMPING_MODE_TYPE = "soft", +): crop_height, crop_width = _center_crop_parse_output_size(output_size) crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size) - return crop_keypoints(inpt, top=crop_top, left=crop_left, height=crop_height, width=crop_width, clamping_mode=clamping_mode) + return crop_keypoints( + inpt, top=crop_top, left=crop_left, height=crop_height, width=crop_width, clamping_mode=clamping_mode + ) @_register_kernel_internal(center_crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False) def _center_crop_keypoints_dispatch(inpt: tv_tensors.KeyPoints, output_size: list[int]) -> tv_tensors.KeyPoints: output, canvas_size = center_crop_keypoints( - inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, output_size=output_size, clamping_mode=inpt.clamping_mode, + inpt.as_subclass(torch.Tensor), + canvas_size=inpt.canvas_size, + output_size=output_size, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) @@ -2766,7 +2811,13 @@ def _resized_crop_keypoints_dispatch( inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: list[int], **kwargs ): output, canvas_size = resized_crop_keypoints( - inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size, clamping_mode=inpt.clamping_mode, + inpt.as_subclass(torch.Tensor), + top=top, + left=left, + height=height, + width=width, + size=size, + clamping_mode=inpt.clamping_mode, ) return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 1774e2db982..a0be77181b6 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -4,8 +4,7 @@ import torch from torchvision import tv_tensors from torchvision.transforms import _functional_pil as _FP -from torchvision.tv_tensors import BoundingBoxFormat -from torchvision.tv_tensors import CLAMPING_MODE_TYPE +from torchvision.tv_tensors import BoundingBoxFormat, CLAMPING_MODE_TYPE from torchvision.utils import _log_api_usage_once @@ -653,7 +652,9 @@ def clamp_bounding_boxes( ) -def _clamp_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE) -> torch.Tensor: +def _clamp_keypoints( + keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE +) -> torch.Tensor: if clamping_mode is None or clamping_mode != "hard": return keypoints.clone() dtype = keypoints.dtype @@ -687,7 +688,9 @@ def 
clamp_keypoints(
         raise ValueError("For keypoints tv_tensor inputs, `canvas_size` must not be passed.")
         if clamping_mode is not None and clamping_mode == "auto":
             clamping_mode = inpt.clamping_mode
-        output = _clamp_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, clamping_mode=clamping_mode)
+        output = _clamp_keypoints(
+            inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, clamping_mode=clamping_mode
+        )
         return tv_tensors.wrap(output, like=inpt)
     else:
         raise TypeError(f"Input can either be a plain tensor or a keypoints tv_tensor, but got {type(inpt)} instead.")
diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py
index 631716a9370..e86b22e5cf3 100644
--- a/torchvision/tv_tensors/__init__.py
+++ b/torchvision/tv_tensors/__init__.py
@@ -1,6 +1,6 @@
 import torch
 
-from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, is_rotated_bounding_format, CLAMPING_MODE_TYPE
+from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, CLAMPING_MODE_TYPE, is_rotated_bounding_format
 from ._image import Image
 from ._keypoints import KeyPoints
 from ._mask import Mask
diff --git a/torchvision/tv_tensors/_keypoints.py b/torchvision/tv_tensors/_keypoints.py
index 04262388c8a..51633031bf6 100644
--- a/torchvision/tv_tensors/_keypoints.py
+++ b/torchvision/tv_tensors/_keypoints.py
@@ -5,9 +5,10 @@
 import torch
 from torch.utils._pytree import tree_flatten
 
-from ._tv_tensor import TVTensor
 from ._bounding_boxes import CLAMPING_MODE_TYPE
+from ._tv_tensor import TVTensor
+
 
 
 class KeyPoints(TVTensor):
     """:class:`torch.Tensor` subclass for tensors with shape ``[..., 2]`` that represent points in an image.
@@ -103,7 +104,10 @@ def _wrap_output(
             output = KeyPoints._wrap(output, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False)
         elif isinstance(output, (tuple, list)):
             # This branch exists for chunk() and unbind()
-            output = type(output)(KeyPoints._wrap(part, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False) for part in output)
+            output = type(output)(
+                KeyPoints._wrap(part, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False)
+                for part in output
+            )
         return output
 
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]

From 4db3c2ecf2d39412ddbc965bce2faa9406c828aa Mon Sep 17 00:00:00 2001
From: Antoine Simoulin
Date: Wed, 1 Oct 2025 12:07:55 -0700
Subject: [PATCH 05/10] lint

---
 torchvision/tv_tensors/__init__.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py
index e86b22e5cf3..1e6f12fb7f7 100644
--- a/torchvision/tv_tensors/__init__.py
+++ b/torchvision/tv_tensors/__init__.py
@@ -34,6 +34,10 @@ def wrap(wrappee, *, like, **kwargs):
             clamping_mode=kwargs.get("clamping_mode", like.clamping_mode),
         )
     elif isinstance(like, KeyPoints):
-        return KeyPoints._wrap(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size), clamping_mode=kwargs.get("clamping_mode", like.clamping_mode))
+        return KeyPoints._wrap(
+            wrappee,
+            canvas_size=kwargs.get("canvas_size", like.canvas_size),
+            clamping_mode=kwargs.get("clamping_mode", like.clamping_mode),
+        )
     else:
         return wrappee.as_subclass(type(like))

From b22688ada8b54ea692462baa323377f46c20ea0c Mon Sep 17 00:00:00 2001
From: Antoine Simoulin
Date: Wed, 1 Oct 2025 12:20:44 -0700
Subject: [PATCH 06/10] fix linting

---
 torchvision/transforms/v2/_meta.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git 
a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py index e0aad2f3899..a217eaa42af 100644 --- a/torchvision/transforms/v2/_meta.py +++ b/torchvision/transforms/v2/_meta.py @@ -85,7 +85,9 @@ def __init__(self, clamping_mode: CLAMPING_MODE_TYPE) -> None: _transformed_types = (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints) - def transform(self, inpt: tv_tensors.TVTensor, params: dict[str, Any]) -> tv_tensors.TVTensor: + def transform( + self, inpt: tv_tensors.BoundingBoxes | tv_tensors.KeyPoints, params: dict[str, Any] + ) -> tv_tensors.BoundingBoxes | tv_tensors.KeyPoints: out: tv_tensors.TVTensor = inpt.clone() # type: ignore[assignment] out.clamping_mode = self.clamping_mode return out From 2e2ea83d0df66e44ff5e2ade568f611b3ed3f536 Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Wed, 1 Oct 2025 12:34:50 -0700 Subject: [PATCH 07/10] fix linting --- torchvision/transforms/v2/_meta.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py index a217eaa42af..ac621df348b 100644 --- a/torchvision/transforms/v2/_meta.py +++ b/torchvision/transforms/v2/_meta.py @@ -85,9 +85,7 @@ def __init__(self, clamping_mode: CLAMPING_MODE_TYPE) -> None: _transformed_types = (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints) - def transform( - self, inpt: tv_tensors.BoundingBoxes | tv_tensors.KeyPoints, params: dict[str, Any] - ) -> tv_tensors.BoundingBoxes | tv_tensors.KeyPoints: + def transform(self, inpt: tv_tensors.TVTensor, params: dict[str, Any]) -> tv_tensors.TVTensor: out: tv_tensors.TVTensor = inpt.clone() # type: ignore[assignment] - out.clamping_mode = self.clamping_mode + out.clamping_mode = self.clamping_mode # type: ignore[assignment] return out From 6f5a3b7a6ad77756b5348e9e51210847f3afd3d6 Mon Sep 17 00:00:00 2001 From: Antoine Simoulin Date: Wed, 1 Oct 2025 12:45:33 -0700 Subject: [PATCH 08/10] fix linting --- torchvision/transforms/v2/_meta.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py index ac621df348b..c23da1a36bc 100644 --- a/torchvision/transforms/v2/_meta.py +++ b/torchvision/transforms/v2/_meta.py @@ -85,7 +85,8 @@ def __init__(self, clamping_mode: CLAMPING_MODE_TYPE) -> None: _transformed_types = (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints) - def transform(self, inpt: tv_tensors.TVTensor, params: dict[str, Any]) -> tv_tensors.TVTensor: - out: tv_tensors.TVTensor = inpt.clone() # type: ignore[assignment] - out.clamping_mode = self.clamping_mode # type: ignore[assignment] + def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes: + # this method works for both `tv_tensors.BoundingBoxes`` and `tv_tensors.KeyPoints`. 
+        out: tv_tensors.BoundingBoxes = inpt.clone()  # type: ignore[assignment]
+        out.clamping_mode = self.clamping_mode
         return out

From 37d73d4d5db8eef76bfdd3556714bd6d5a949faf Mon Sep 17 00:00:00 2001
From: Antoine Simoulin
Date: Wed, 1 Oct 2025 14:05:08 -0700
Subject: [PATCH 09/10] fix tests

---
 test/test_transforms_v2.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index c98c16e0d4c..b5a8e4788c3 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -2086,7 +2086,6 @@ def test_functional(self, make_input):
             (F.rotate_image, tv_tensors.Image),
             (F.rotate_mask, tv_tensors.Mask),
             (F.rotate_video, tv_tensors.Video),
-            (F.rotate_keypoints, tv_tensors.KeyPoints),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
@@ -4415,7 +4414,6 @@ def test_functional(self, make_input):
             (F.resized_crop_image, tv_tensors.Image),
             (F.resized_crop_mask, tv_tensors.Mask),
             (F.resized_crop_video, tv_tensors.Video),
-            (F.resized_crop_keypoints, tv_tensors.KeyPoints),
         ],
     )
     def test_functional_signature(self, kernel, input_type):

From 5f7d425693a89ab1cf96f6ecc1c410d4963890bf Mon Sep 17 00:00:00 2001
From: Antoine Simoulin
Date: Thu, 9 Oct 2025 08:16:01 -0700
Subject: [PATCH 10/10] remove redundant lines in tests

---
 test/test_transforms_v2.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index b5a8e4788c3..574707fd4f2 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -5790,9 +5790,6 @@ def test_clamping_mode(self, constructor_clamping_mode, clamping_mode, pass_pure
         keypoints = tv_tensors.KeyPoints(
             [[0, 100], [0, 100]], canvas_size=(10, 10), clamping_mode=constructor_clamping_mode
         )
-        expected_clamped_output = (
-            torch.tensor([[0, 10], [0, 10]]) if clamping_mode == "hard" else torch.tensor([[0, 100], [0, 100]])
-        )
         expected_clamped_output = (
             torch.tensor([[0, 9], [0, 9]]) if clamping_mode == "hard" else torch.tensor([[0, 100], [0, 100]])
         )
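
Usage sketch (reviewer note, not part of the patch series): assuming the ten patches above are applied as-is, ``KeyPoints`` carries a ``clamping_mode`` attribute ("soft", "hard", or ``None``, defaulting to "soft"), the keypoint kernels forward it to ``clamp_keypoints``, and the functional's and ``ClampKeyPoints``'s default of "auto" falls back to the input's attribute. Only "hard" mode actually clips coordinates in ``_clamp_keypoints`` as written. The snippet below is a minimal illustration of that intended behaviour; the coordinate values are made up for the example.

from torchvision import tv_tensors
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F

# One point far outside a 10x10 canvas; the attribute requests hard clamping.
points = tv_tensors.KeyPoints([[0, 100], [5, 3]], canvas_size=(10, 10), clamping_mode="hard")

# Default clamping_mode="auto" resolves to points.clamping_mode ("hard"),
# so coordinates are clipped to [0, canvas_size - 1]: [[0, 9], [5, 3]].
clamped = F.clamp_keypoints(points)

# Passing None disables clamping for this call and returns the points unchanged (as a clone).
untouched = v2.ClampKeyPoints(clamping_mode=None)(points)

# SetClampingMode now accepts KeyPoints as well; it only rewrites the attribute
# consulted by later transforms and does not clamp anything itself.
soft_points = v2.SetClampingMode("soft")(points)
assert soft_points.clamping_mode == "soft"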