Skip to content

Commit d80fc3b

Browse files
committed
simplify PIL comparisions
1 parent 80dc7dd commit d80fc3b

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

test/common_utils.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@
2020
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
2121
from torchvision import io, tv_tensors
2222
from torchvision.transforms._functional_tensor import _max_value as get_max_value
23-
from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
23+
from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
24+
from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available
2425
from torchvision.utils import _Image_fromarray
2526

2627

2728
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
2829
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
2930
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
31+
CVCUDA_AVAILABLE = _is_cvcuda_available()
3032
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
3133
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
3234
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@@ -275,6 +277,17 @@ def combinations_grid(**kwargs):
275277
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
276278

277279

280+
def cvcuda_to_pil_compatible_tensor(tensor):
281+
tensor = cvcuda_to_tensor(tensor)
282+
if tensor.ndim != 4:
283+
raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.")
284+
if tensor.shape[0] != 1:
285+
raise ValueError(
286+
f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}."
287+
)
288+
return tensor.squeeze(0).cpu()
289+
290+
278291
class ImagePair(TensorLikePair):
279292
def __init__(
280293
self,
@@ -287,6 +300,16 @@ def __init__(
287300
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
288301
actual, expected = (to_image(input) for input in [actual, expected])
289302

303+
if CVCUDA_AVAILABLE and all(isinstance(input, _import_cvcuda().Tensor) for input in [actual, expected]):
304+
actual, expected = (cvcuda_to_tensor(input) for input in [actual, expected])
305+
306+
if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor) and isinstance(expected, PIL.Image.Image):
307+
actual = cvcuda_to_pil_compatible_tensor(actual)
308+
expected = to_image(expected)
309+
310+
if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor):
311+
actual = cvcuda_to_pil_compatible_tensor(actual)
312+
290313
super().__init__(actual, expected, **other_parameters)
291314
self.mae = mae
292315

test/test_transforms_v2.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
assert_equal,
2525
cache,
2626
cpu_and_cuda,
27+
cvcuda_to_pil_compatible_tensor,
2728
freeze_rng_state,
2829
ignore_jit_no_profile_information_warning,
2930
make_bounding_boxes,
@@ -6427,11 +6428,7 @@ def test_image_correctness(self, num_output_channels, color_space, make_input, f
64276428
actual = fn(image, num_output_channels=num_output_channels)
64286429

64296430
if make_input is make_image_cvcuda:
6430-
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
6431-
actual = actual.squeeze(0)
6432-
# drop the batch dimension
6433-
image = F.cvcuda_to_tensor(image).to(device="cpu")
6434-
image = image.squeeze(0)
6431+
image = cvcuda_to_pil_compatible_tensor(image)
64356432

64366433
expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels))
64376434

@@ -6532,11 +6529,7 @@ def test_image_correctness(self, make_input, fn):
65326529
actual = fn(image)
65336530

65346531
if make_input is make_image_cvcuda:
6535-
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
6536-
actual = actual.squeeze(0)
6537-
# drop the batch dimension
6538-
image = F.cvcuda_to_tensor(image).to(device="cpu")
6539-
image = image.squeeze(0)
6532+
image = cvcuda_to_pil_compatible_tensor(image)
65406533

65416534
expected = F.to_image(F.grayscale_to_rgb(F.to_pil_image(image)))
65426535

0 commit comments

Comments
 (0)