Skip to content

Commit 745d7c7

Browse files
committed
perspective verified
1 parent d6711d3 commit 745d7c7

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

test/test_transforms_v2.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5204,11 +5204,8 @@ def test_image_functional_correctness(self, coefficients, interpolation, fill, m
52045204
image, startpoints=None, endpoints=None, coefficients=coefficients, interpolation=interpolation, fill=fill
52055205
)
52065206
if make_input is make_image_cvcuda:
5207-
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
5208-
actual = actual.squeeze(0)
5209-
# drop the batch dimension
5210-
image = F.cvcuda_to_tensor(image).to(device="cpu")
5211-
image = image.squeeze(0)
5207+
actual = cvcuda_to_pil_compatible_tensor(actual)
5208+
image = cvcuda_to_pil_compatible_tensor(image)
52125209

52135210
expected = F.to_image(
52145211
F.perspective(
@@ -5234,7 +5231,7 @@ def test_image_functional_correctness(self, coefficients, interpolation, fill, m
52345231
# visually the results are the same on real images,
52355232
# realistically, the diff is not visible to the human eye
52365233
tolerance = 255 if interpolation is transforms.InterpolationMode.NEAREST else 125
5237-
torch.testing.assert_close(actual, expected, rtol=0, atol=tolerance)
5234+
assert_close(actual, expected, rtol=0, atol=tolerance)
52385235

52395236
def _reference_perspective_bounding_boxes(self, bounding_boxes, *, startpoints, endpoints):
52405237
format = bounding_boxes.format

torchvision/transforms/v2/_geometry.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torchvision.ops.boxes import box_iou
1212
from torchvision.transforms.functional import _get_perspective_coeffs
1313
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
14-
from torchvision.transforms.v2.functional._utils import _FillType
14+
from torchvision.transforms.v2.functional._utils import _FillType, is_cvcuda_tensor
1515

1616
from ._transform import _RandomApplyTransform
1717
from ._utils import (
@@ -936,6 +936,8 @@ class RandomPerspective(_RandomApplyTransform):
936936

937937
_v1_transform_cls = _transforms.RandomPerspective
938938

939+
_transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,)
940+
939941
def __init__(
940942
self,
941943
distortion_scale: float = 0.5,

0 commit comments

Comments
 (0)